最近做了一次 Pandas 的性能优化,对比下来,在小规模的数据集上,Pandas 的性能会比 SQL 更好。当数据量继续上涨,SQL 会更好些。如果数据到了几百 GB,甚至 TB 级别,目前看只能够用 Spark 了。
Pandas
使用 Pandas 的时候,建议将数据集当做一个整体来操作,这样的性能最好。比如,需要两个数据集的关联查询再求和,可以用apply
,但是性能比较差。下面的例子中,通过apply
消耗了 8.84 秒。背后的原因是有一个迭代器,逐条循环记录,所以性能很慢。
import pandas as pd
import random
data1 = {'user_id': [i for i in range(10000)],
'name': ['Name_' + str(i) for i in range(1, 10001)]}
df1 = pd.DataFrame(data1)
data2 = {'user_id': [random.randint(1, 10000) for _ in range(100000)],
'month': ['2023-{:02d}'.format(random.randint(1, 12)) for _ in range(100000)],
'qty': [random.randint(1, 100) for _ in range(100000)]}
df2 = pd.DataFrame(data2)
def agg_sum(x):
filtered_df = df2[df2['user_id'] == x['user_id']]
return filtered_df[['qty']].sum()
# 8.84 seconds
df1[['qty']] = df1.apply(lambda x: agg_sum(x), axis=1, result_type='expand')
如果将Pandas
作为整体,只需要 21.1ms。
# 21.1 ms
agg_df = pd.merge(df1, df2, on='user_id').groupby(['user_id'])['qty'].sum().reset_index()
Out[2]:
user_id qty
0 1 286
1 2 78
2 3 426
3 4 575
4 5 267
... ... ...
9994 9995 380
9995 9996 311
9996 9997 364
9997 9998 430
9998 9999 350
[9999 rows x 2 columns]
SQL
如果用 SQL,写起来就简单直观,速度也很快。
SELECT df1.user_id, sum(df2.qty) qty
FROM df1
JOIN df2 ON df1.user_id = df2.user_id
Spark
当数据量上升到几十 GB,几百 GB 的时候,就需要考虑用 Spark 了。下面是一个例子,因为数据量小,体现不出什么优势,用时大概 5 秒。
from pyspark.sql import SparkSession
import time
spark = SparkSession.builder.appName('Postgres').master('spark://localhost:7077').config("spark.driver.extraClassPath", "postgresql-42.6.0.jar").getOrCreate()
jdbc_url = 'jdbc:postgresql://localhost:5432/database'
properties = {
"user": "username",
"password": "password",
"driver": "org.postgresql.Driver"
}
start_time = time.time()
users = spark.read.jdbc(url=jdbc_url, table="users", properties=properties)
orders = spark.read.jdbc(url=jdbc_url, table="orders", properties=properties)
users.createOrReplaceTempView("df1")
orders.createOrReplaceTempView("df2")
query = """
SELECT df1.user_id, sum(df2.qty) qty
FROM df1
JOIN df2 ON df1.user_id = df2.user_id
"""
result = spark.sql(query).collect()
print("--- %s seconds ---" % (time.time() - start_time))
spark.stop()