Using a Pandas UDF (Spark >= 2.3):
import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.functions import pandas_udf, PandasUDFType
import datetime as dt
# Sample orders for one account: mixed positive/negative amounts,
# dates as M/D/YY strings (parsed later inside the UDF).
data = [
    {'account': '1', 'order_date': '11/18/20', 'amount': -34.99},
    {'account': '1', 'order_date': '10/28/20', 'amount': -4.99},
    {'account': '1', 'order_date': '9/11/20', 'amount': 4.99},
    {'account': '1', 'order_date': '9/2/20', 'amount': 9.98}]

# BUG FIX: `self` is undefined at module level, so `self._spark` raised a
# NameError. Build (or reuse) a local SparkSession here instead.
from pyspark.sql import SparkSession
_spark = SparkSession.builder.appName("SparkByExamples.com").getOrCreate()

# For simplicity, seed a "balance" column with 0.0; the grouped-map UDF
# fills it in and must return rows matching this exact schema.
input_df = _spark.createDataFrame(data).withColumn('balance', f.lit(0.0))
input_df.show()
schema = input_df.schema
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def _get_running_total(pdf: pd.DataFrame):
    """Compute a running balance for one account group.

    Rows are processed from the most recent order to the oldest. The most
    recent amount starts the total; after that, positive amounts are added
    to the previous balance, while for negative amounts a positive carried
    balance is first clamped to 0 before the charge is applied.

    Receives one group's rows as a pandas DataFrame and must return a
    DataFrame matching the module-level `schema`.
    """
    import os
    # Work around the IPC format change in pyarrow >= 0.15 (SPARK-29367).
    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = "1"

    pdf['order_date'] = pd.to_datetime(pdf['order_date'])
    df = pdf.sort_values(by=['order_date'], ascending=False)
    # BUG FIX: sort_values keeps the original index labels, so the first
    # sorted row is not guaranteed to carry label 0; the old `if i == 0`
    # test could then skip the first-row branch and add None to a float.
    # Reset to a clean positional index so label 0 really is the first row.
    df = df.reset_index(drop=True)
    df['order_date'] = df['order_date'].apply(
        lambda x: dt.datetime.strftime(x, '%m/%d/%Y'))

    previous_balance = None
    for i, row in df.iterrows():
        current_amount = row['amount']
        if i == 0:
            # Most recent order starts the running total.
            running_total = current_amount
        elif current_amount > 0:
            running_total = current_amount + previous_balance
        else:
            # Negative amount: a positive carried balance resets to 0
            # before the charge is applied.
            if previous_balance > 0:
                previous_balance = 0
            running_total = previous_balance + current_amount
        # Public .at accessor instead of the private _set_value API.
        df.at[i, 'balance'] = running_total
        previous_balance = running_total
    return df
# Apply the grouped-map UDF per account: Spark hands each group to
# _get_running_total as a pandas DataFrame and unions the results.
# NOTE(review): groupby().apply(pandas_udf) is the Spark 2.x style; Spark 3
# deprecates it in favor of groupby().applyInPandas — confirm target version.
input_df.groupby('account').apply(_get_running_total).show()
Updated — runs fine on Spark 3.2:
import pandas as pd
import pyspark.sql.functions as f
from pyspark.sql.functions import pandas_udf, PandasUDFType
import datetime as dt
from pyspark import SparkContext
from pyspark.sql import SparkSession
# Spin up (or reuse) a local SparkSession for this example.
_spark = SparkSession.builder.appName("SparkByExamples.com").getOrCreate()

# One account's orders, most recent first; amounts may be negative.
data = [
    {'account': '1', 'order_date': '11/18/20', 'amount': -34.99},
    {'account': '1', 'order_date': '10/28/20', 'amount': -4.99},
    {'account': '1', 'order_date': '9/11/20', 'amount': 4.99},
    {'account': '1', 'order_date': '9/2/20', 'amount': 9.98},
]

# Seed an all-zero "balance" column; the grouped-map UDF fills it in and
# must hand back rows matching this exact schema.
input_df = _spark.createDataFrame(data).withColumn('balance', f.lit(0.0))
input_df.show()
schema = input_df.schema
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def _get_running_total(pdf: pd.DataFrame):
    """Compute a running balance for one account group.

    Rows are walked from the most recent order to the oldest. The most
    recent amount starts the total; thereafter positive amounts add to the
    previous balance, and for negative amounts a positive carried balance
    is first clamped to 0 before the charge is applied.

    Receives one group's rows as a pandas DataFrame and must return a
    DataFrame matching the module-level `schema`.
    """
    import os
    # Work around the IPC format change in pyarrow >= 0.15 (SPARK-29367).
    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = "1"

    pdf['order_date'] = pd.to_datetime(pdf['order_date'])
    df = pdf.sort_values(by=['order_date'], ascending=False)
    # BUG FIX: sort_values keeps the original index labels, so the first
    # sorted row is not guaranteed to carry label 0; the old `if i == 0`
    # test could then skip the first-row branch and add None to a float.
    # Reset to a clean positional index so label 0 really is the first row.
    df = df.reset_index(drop=True)
    df['order_date'] = df['order_date'].apply(
        lambda x: dt.datetime.strftime(x, '%m/%d/%Y'))

    previous_balance = None
    for i, row in df.iterrows():
        current_amount = row['amount']
        if i == 0:
            # Most recent order starts the running total.
            running_total = current_amount
        elif current_amount > 0:
            running_total = current_amount + previous_balance
        else:
            # Negative amount: a positive carried balance resets to 0
            # before the charge is applied.
            if previous_balance > 0:
                previous_balance = 0
            running_total = previous_balance + current_amount
        # Public .at accessor instead of the private _set_value API.
        df.at[i, 'balance'] = running_total
        previous_balance = running_total
    return df
# Apply the grouped-map UDF per account: Spark hands each group to
# _get_running_total as a pandas DataFrame and unions the results.
# NOTE(review): groupby().apply(pandas_udf) is the Spark 2.x style; Spark 3
# deprecates it in favor of groupby().applyInPandas — confirm target version.
input_df.groupby('account').apply(_get_running_total).show()