【发布时间】:2019-03-09 02:09:39
【问题描述】:
我有一个如下所示的 pyspark 数据框。
+---+-------+--------+
|age|balance|duration|
+---+-------+--------+
| 2| 2143| 261|
| 44| 29| 151|
| 33| 2| 76|
| 50| 1506| 92|
| 33| 1| 198|
| 35| 231| 139|
| 28| 447| 217|
| 2| 2| 380|
| 58| 121| 50|
| 43| 693| 55|
| 41| 270| 222|
| 50| 390| 137|
| 53| 6| 517|
| 58| 71| 71|
| 57| 162| 174|
| 40| 229| 353|
| 45| 13| 98|
| 57| 52| 38|
| 3| 0| 219|
| 4| 0| 54|
+---+-------+--------+
我的预期输出应该是这样的,
+---+-------+--------+-------+-----------+------------+
|age|balance|duration|age_out|balance_out|duration_out|
+---+-------+--------+-------+-----------+------------+
| 2| 2143| 261| 1| 1| 0|
| 44| 29| 151| 0| 0| 0|
| 33| 2| 76| 0| 0| 0|
| 50| 1506| 92| 0| 1| 0|
| 33| 1| 198| 0| 0| 0|
| 35| 231| 139| 0| 0| 0|
| 28| 447| 217| 0| 0| 0|
| 2| 2| 380| 1| 0| 0|
| 58| 121| 50| 0| 0| 0|
| 43| 693| 55| 0| 0| 0|
| 41| 270| 222| 0| 0| 0|
| 50| 390| 137| 0| 0| 0|
| 53| 6| 517| 0| 0| 1|
| 58| 71| 71| 0| 0| 0|
| 57| 162| 174| 0| 0| 0|
| 40| 229| 353| 0| 0| 0|
| 45| 13| 98| 0| 0| 0|
| 57| 52| 38| 0| 0| 0|
| 3| 0| 219| 1| 0| 0|
| 4| 0| 54| 0| 0| 0|
+---+-------+--------+-------+-----------+------------+
这里我的目标是使用我在下面的 python 代码中描述的四分位数方法来识别数据集中的异常记录。如果我们发现任何异常记录,那么我们需要将它们标记为 1,否则标记为 0。
我可以通过下面的代码使用 python 做同样的事情。
import numpy as np
def outliers_iqr(ys):
quartile_1, quartile_3 = np.percentile(ys, [25, 75])
iqr = quartile_3 - quartile_1
lower_bound = quartile_1 - (iqr * 1.5)
upper_bound = quartile_3 + (iqr * 1.5)
ser = np.zeros(len(ys))
pos =np.where((ys > upper_bound) | (ys < lower_bound))[0]
ser[pos]=1
return(ser)
但我想在 pyspark 中做同样的事情。有人可以帮助我吗?
我的 pyspark 代码:
def outliers_iqr(ys):
quartile_1, quartile_3 = np.percentile(ys, [25, 75])
iqr = quartile_3 - quartile_1
lower_bound = quartile_1 - (iqr * 1.5)
upper_bound = quartile_3 + (iqr * 1.5)
ser = np.zeros(len(ys))
pos =np.where((ys > upper_bound) | (ys < lower_bound))[0]
ser[pos]=1
return(float(ser))
outliers_iqr_udf = udf(outliers_iqr, FloatType())
DF.withColumn('age_out', outliers_iqr_udf(DF.select('age').collect())).show()
【问题讨论】:
标签: python-3.x apache-spark pyspark