【问题标题】:Pyspark - Looking to find indexes on the top N largest values in an array columnPyspark - 查找数组列中前 N 个最大值的索引
【发布时间】:2022-01-26 02:50:07
【问题描述】:

我正在寻找替换以下 numpy 命令的功能:

top_n_idx = np.argsort(cosine_sim[idx])[::-1][1:11]

样本数据:

array_col

[0.1,0.5,0.2,0.5,0.9]
[0.1,0.9,0.5,0.2,0.35]

这是我目前的代码:

df.select("array_col", F.slice(F.sort_array(F.col("array_col"), asc=False), 1, 3).alias("top_scores")).show()

array_col               top_scores

[0.1,0.5,0.2,0.55,0.9]  [0.9, 0.55, 0.5]
[0.1,0.9,0.5,0.2,0.35]  [0.9, 0.5, 0.35]

现在,我想做的是在array_col 中找到与“top_scores”列相对应的索引。

array_col               top_scores.       top_score_idx

[0.1,0.5,0.2,0.55,0.9]  [0.9, 0.55, 0.5]  [5, 4, 2]
[0.1,0.9,0.5,0.2,0.35]  [0.9, 0.5, 0.35]  [2, 3, 5]

我最终将使用top_score_idx 来获取另一个数组column 中的对应索引。

【问题讨论】:

    标签: dataframe pyspark


    【解决方案1】:

    对于 Spark 2.4+,使用 array_positiontransform 函数转换 top_scores 数组并在 array_col 列中获取它们从 1 开始的索引。

    df \
    .select("array_col", F.slice(F.sort_array(F.col("array_col"), asc=False), 1, 3).alias("top_scores")) \
    .withColumn("top_score_idx", F.expr("transform(top_scores, v -> array_position(array_col, v))")) \
    .show()
    
    # +--------------------------+----------------+-------------+
    # |array_col                 |top_scores      |top_score_idx|
    # +--------------------------+----------------+-------------+
    # |[0.1, 0.5, 0.2, 0.55, 0.9]|[0.9, 0.55, 0.5]|[5, 4, 2]    |
    # |[0.1, 0.9, 0.5, 0.2, 0.35]|[0.9, 0.5, 0.35]|[2, 3, 5]    |
    # +--------------------------+----------------+-------------+
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-01-07
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-07-11
      相关资源
      最近更新 更多