Pandas中的性能优化

Database and Ruby, Python, History


最近做了一次 Pandas 的性能优化,对比下来,在小规模的数据集上,Pandas 的性能会比 SQL 更好。当数据量继续上涨,SQL 会更好些。如果数据到了几百 GB,甚至 TB 级别,目前看只能够用 Spark 了。

Pandas

使用 Pandas 的时候,建议将数据集当做一个整体来操作,这样的性能最好。比如,需要两个数据集的关联查询再求和,可以用apply,但是性能比较差。下面的例子中,通过apply消耗了 8.84 秒。背后的原因是有一个迭代器,逐条循环记录,所以性能很慢。

import pandas as pd
import random

data1 = {'user_id': [i for i in range(10000)],
         'name': ['Name_' + str(i) for i in range(1, 10001)]}
df1 = pd.DataFrame(data1)

data2 = {'user_id': [random.randint(1, 10000) for _ in range(100000)],
         'month': ['2023-{:02d}'.format(random.randint(1, 12)) for _ in range(100000)],
         'qty': [random.randint(1, 100) for _ in range(100000)]}
df2 = pd.DataFrame(data2)

def agg_sum(x):
    filtered_df = df2[df2['user_id'] == x['user_id']]
    return filtered_df[['qty']].sum()

# 8.84 seconds
df1[['qty']] = df1.apply(lambda x: agg_sum(x), axis=1, result_type='expand')

如果将Pandas作为整体,只需要 21.1ms。

# 21.1 ms
agg_df = pd.merge(df1, df2, on='user_id').groupby(['user_id'])['qty'].sum().reset_index()

Out[2]:
      user_id  qty
0           1  286
1           2   78
2           3  426
3           4  575
4           5  267
...       ...  ...
9994     9995  380
9995     9996  311
9996     9997  364
9997     9998  430
9998     9999  350

[9999 rows x 2 columns]

SQL

如果用 SQL,写起来就简单直观,速度也很快。

SELECT df1.user_id, sum(df2.qty) qty
FROM df1
JOIN df2 ON df1.user_id = df2.user_id

Spark

当数据量上升到几十 GB,几百 GB 的时候,就需要考虑用 Spark 了。下面是一个例子,因为数据量小,体现不出什么优势,用时大概 5 秒。

from pyspark.sql import SparkSession
import time

spark = SparkSession.builder.appName('Postgres').master('spark://localhost:7077').config("spark.driver.extraClassPath", "postgresql-42.6.0.jar").getOrCreate()

jdbc_url = 'jdbc:postgresql://localhost:5432/database'

properties = {
    "user": "username",
    "password": "password",
    "driver": "org.postgresql.Driver"
}

start_time = time.time()

users = spark.read.jdbc(url=jdbc_url, table="users", properties=properties)
orders = spark.read.jdbc(url=jdbc_url, table="orders", properties=properties)


users.createOrReplaceTempView("df1")
orders.createOrReplaceTempView("df2")

query = """
    SELECT df1.user_id, sum(df2.qty) qty
    FROM df1
    JOIN df2 ON df1.user_id = df2.user_id
"""

result = spark.sql(query).collect()

print("--- %s seconds ---" % (time.time() - start_time))

spark.stop()