【问题标题】:find values closest to a list of values in pyspark在 pyspark 中查找最接近值列表的值
【发布时间】:2021-11-20 19:21:22
【问题描述】:

假设有这个 Pyspark 数据框:

x = np.random.randint(1, 100, 1000)
y = np.random.randint(1, 100, 1000)
z = np.random.randint(1, 100, 1000)

df = pd.DataFrame({'x': x, 'y': y, 'z': z})
spark_df = spark.createDataFrame(df)

假设有这个值列表:

lst = [10, 20, 30]

我想检索 spark_df 的所有 3 (=len(lst)) 行,以便 lstspark_df.x 的每个值之间的差异最小。我想将这三个值检索为 spark 数据框。例如:

+---+---+---+
|  x|  y|  z|
+---+---+---+
| 11| 32| 84|
| 22| 12| 38|
| 29| 14| 12|
+---+---+---+

在这种情况下:

  • 11 是最接近 spark_df.x 值的 10
  • 22 是最接近 spark_df.x 值的 20
  • 29 是最接近 spark_df.x 的值 30

如何在 Pyspark 3+ 中实现这个结果?

注意:这只是一个玩具示例,值列表可能有数千个。

【问题讨论】:

  • 如果您询问如何改进现有的工作代码,您应该询问代码审查。此类问题在 Stack Overflow 上是题外话,因为它们是基于意见的。
  • 您在问题中包含了一个(用您自己的话来说是“坏”的)解决方案。因此,如果不是“更好”的解决方案,您还不清楚您要求什么,但您不会费心用具体、客观的术语解释您想要什么。您提供的解决方案有什么问题?您希望您的解决方案实现当前错误的具体指标是什么?当有人说你的问题有问题时,不要嘲笑;而是解决问题。
  • 你不应该仅仅因为我发表了评论就认为我是第一个投反对票的人。除此之外,我的第二条评论与第一条评论的内容相同,只是更详细地解释了help center 为您说的内容。

标签: python dataframe apache-spark pyspark


【解决方案1】:

第一步:将lst的元素与x的值不同的列添加到数据框:

from pyspark.sql import functions as F

diffs = [F.abs(F.col("x") - F.lit(c)).alias(f"diff_{c}") for c in lst]
df_with_diffs = spark_df.select("*", *diffs)
+---+---+---+-------+-------+-------+
|  x|  y|  z|diff_10|diff_20|diff_30|
+---+---+---+-------+-------+-------+
| 15| 34| 20|      5|      5|     15|
| 12| 45| 24|      2|      8|     18|
| 86| 49| 13|     76|     66|     56|
+---+---+---+-------+-------+-------+
only showing top 3 rows

第 2 步: 收集每个 diff 列的最小值并选择相应的行:

mins=df_with_diffs.select(*[F.min(f"diff_{c}") for c in lst]).first()

filter=" or ".join([f"(diff_{c} = {mins[i]})" for i,c in enumerate(lst)])
df_with_diffs.filter(filter).select(spark_df.columns).show()
+---+---+---+
|  x|  y|  z|
+---+---+---+
| 12| 45| 24|
| 22| 28| 58|
| 27| 96| 36|
+---+---+---+

第 2 步(原始答案):对每个新创建的列使用 min_by 以找到差异最小的行。对于lst 的每个值,这将返回一行。那么所有这些行都是unioned

agg_cols = [[F.expr(f"min_by({c}, diff_{val})").alias(c) for c in spark_df.columns] 
  for val in lst]

import functools
result = functools.reduce(lambda a,b: a.union(df_with_diffs.agg(*b)), agg_cols[1:], 
          df_with_diffs.agg(*agg_cols[0]))
result.show()

【讨论】:

  • 不确定这个解决方案是否有效。在玩具示例中,我报告了一个包含三个值的列表,但实际长度可能约为数千。你怎么看?
  • 我认为您必须对lst 的每个条目对整个数据帧执行一次聚合。但让我们拭目以待,看看其他回答者是否找到更好的方法
  • 我更改了第二步,以便明确收集最小值,然后应用普通过滤器。我想这种方法可能更快
  • 从某种意义上说,您提出的解决方案似乎与我的相似(只需在主帖中更新它):我使用 pandas_udf 来检索 mins 值(不创建很多列)并且我使用isin 内置函数避免你写的大过滤函数。但还是不满意
【解决方案2】:

经过我自己的一些实验,我提出了这个解决方案。起初,我想避免使用 pandas_udf 函数,但它看起来优雅、pythonic 和有效。

第 1 步: 创建 pandas_udf 以获得最小值列表:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import IntegerType, ArrayType
import pyspark.sql.functions as func

bdc_values = sc.broadcast([10, 20, 30])

@pandas_udf(ArrayType(IntegerType()))
def get_x_min(x: pd.Series) -> ArrayType(IntegerType()):
   values = bdc_values.value
   return [x.iloc[(x - v).abs().argmin()] for v in values]

第 2 步: 以这种方式将 pandas_udf 应用到数据框并收集值:

mins = spark_df.agg(get_x_min('x')).first()[0]

第 3 步: 用isin函数过滤spark_df,然后最终去重:

result = spark_df.filter(func.col('x').isin(mins)).dropDuplicates(['x'])

【讨论】:

    猜你喜欢
    • 2020-02-13
    • 2021-11-16
    • 2017-06-25
    • 1970-01-01
    • 1970-01-01
    • 2015-07-26
    • 1970-01-01
    • 2012-11-14
    • 2015-08-03
    相关资源
    最近更新 更多