【Question Title】: generate disjoint sets of Spark DataFrame
【Posted】: 2023-03-31 22:26:02
【Question Description】:

I want to group the rows of a Spark DataFrame whenever one of several columns has an equivalent value. For example, for the following df:

  val df = Seq(
    ("a1", "b1", "c1"),
    ("a1", "b2", "c2"),
    ("a3", "b2", "c3"),
    ("a4", "b4", "c3"),
    ("a5", "b5", "c5")
  ).toDF("a", "b", "c")

I want rows to fall into the same group whenever the values of column a, b, or c match. In the example, field a of the first row matches the second row, field b of the second row matches the third row, and field c of the third row matches the fourth row, so all four rows belong to the same set (think union-find). The fifth row is a singleton set.

val grouped = Seq(
  ("a1", "b1", "c1", "1"),
  ("a1", "b2", "c2", "1"),
  ("a3", "b2", "c3", "1"),
  ("a4", "b4", "c3", "1"),
  ("a5", "b5", "c5", "2")
).toDF("a", "b", "c", "group")

I added the group column to convey the intuition of the expected disjoint-set result.
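The comments below mention a GraphX / connected-components approach. As a rough illustration of that union-find idea, here is a minimal PySpark sketch using the external GraphFrames package instead of Scala GraphX (GraphFrames, the variable names, and the checkpoint directory are assumptions, not part of the question):

import pyspark.sql.functions as F
from graphframes import GraphFrame  # external package, assumed to be installed

# Assume `df` is the PySpark equivalent of the question's DataFrame with columns a, b, c.
rows = df.withColumn("id", F.monotonically_increasing_id())
vertices = rows.select("id")

# Connect two rows whenever they share a value in any one of the columns.
edges = None
for c in ["a", "b", "c"]:
    e = (rows.alias("l").join(rows.alias("r"), on=c)
             .select(F.col("l.id").alias("src"), F.col("r.id").alias("dst"))
             .where("src < dst"))
    edges = e if edges is None else edges.union(e)

# connectedComponents() requires a checkpoint directory to be set, e.g.
# spark.sparkContext.setCheckpointDir("/tmp/cc-checkpoints")
g = GraphFrame(vertices, edges)
components = g.connectedComponents()  # one "component" id per row id
grouped = rows.join(components, on="id").withColumnRenamed("component", "group")

Rows connected through any shared value end up with the same component id, which plays the role of the group column above.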

【Question Comments】:

  • A bit confused about your expected output. Why is 2 assigned to the last row? Can you explain the logic?
  • @Raghu Indeed. I have updated my question.
  • graphx it will be…
  • Have you had a chance to look at the answers? Would love to know what you think.
  • I still need to look them over, thanks! I have used Spark GraphX to generate the union-find data structure. I may fall back to a more ad hoc solution, but the results look fair so far…

Tags: apache-spark apache-spark-sql


【Solution 1】:

Try this and let me know how it goes. Basically, we replace each value with its occurrence count and then keep only the rows where every count is 1. Warning: this is computationally expensive because it uses collect().

import pyspark.sql.functions as F

# Test data
tst = sqlContext.createDataFrame(
    [('a1','b1','c1','d1'),('a1','b2','c2','d2'),('a3','b2','c3','d6'),
     ('a4','b4','c3','d7'),('a5','b5','c5','d7'),('a6','b6','c6','d27'),
     ('a9','b88','c54','d71')],
    schema=['a','b','c','d'])
#%% create a unique id for the records
tst_id = tst.withColumn("id", F.monotonically_increasing_id())
#%% arrays to store the counts for each value; this is computationally intensive
#   because collect() brings all the data to the driver
val_arr = []
repl_arr = []
for x in tst.columns:
    tst_agg = tst.groupby(x).count().collect()
    val_arr = val_arr + [y[x] for y in tst_agg]
    repl_arr = repl_arr + [y['count'] for y in tst_agg]

#%% replace the values with their counts; wrapping map() in list() is needed on Python 3,
#   and the string conversion depends on the data type
df_repl = tst_id.replace(list(map(str, val_arr)), list(map(str, repl_arr)))
#%% Note: sum() is the built-in Python function; this checks whether all of a row's values occur exactly once
df_repl_sum = df_repl.withColumn("sum", sum([F.col(x) for x in tst.columns]))
#%% extract the rows whose values occur only once in every column
df_select = df_repl_sum.where(F.col('sum') == len(tst.columns))
#%% join with the main data to see the disjoint values
df_res = df_select.select('id', 'sum').join(tst_id, on='id', how='left')

Result:

+------------+---+---+---+---+---+
|          id|sum|  a|  b|  c|  d|
+------------+---+---+---+---+---+
|146028888064|4.0| a6| b6| c6|d27|
|171798691840|4.0| a9|b88|c54|d71|
+------------+---+---+---+---+---+

If you need a group column, change the final join to a right join and derive the group column from the sum value: F.when(F.col('sum')==len(tst.columns),1).otherwise(0)
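A minimal sketch of that variation, assuming the variable names from the code above (here group is just a 1/0 flag marking the fully-unique rows, not the disjoint-set ids from the question):

# Right join keeps every original row; rows whose values are all unique get group = 1.
df_res_all = df_select.select('id', 'sum').join(tst_id, on='id', how='right')
df_grouped = df_res_all.withColumn(
    "group", F.when(F.col('sum') == len(tst.columns), 1).otherwise(0))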

【Comments】:

【Solution 2】:

I just want to add this as a new answer because I'm not very sure how cube performs compared to collect(). But I feel this is better than my previous answer. Try this.

import pyspark.sql.functions as F

# Test data
tst = sqlContext.createDataFrame(
    [('a1','b1','c1','d1'),('a1','b2','c2','d2'),('a3','b2','c3','d6'),
     ('a4','b4','c3','d7'),('a5','b5','c5','d7'),('a6','b6','c6','d27'),
     ('a9','b88','c54','d71')],
    schema=['a','b','c','d'])
#%% cube the columns and count the combinations
tst_res1 = tst.cube('a','b','c','d').count()
# We need the count of individual values per column, so we count how many nulls each cube row has
tst_nc = tst_res1.withColumn("null_count", sum([F.when(F.col(x).isNull(), 1).otherwise(0) for x in tst_res1.columns]))
# Keep only rows with 3 nulls (single-column aggregates, since we have 4 columns) whose value occurs more than once
tst_flt = tst_nc.filter((F.col('null_count') == len(tst.columns) - 1) & (F.col('count') > 1))
# coalesce to get the elements that occur more than once
tst_coala = tst_flt.withColumn("elements", F.coalesce(*tst.columns))
# collect the elements that occur more than once into a single list
tst_array = tst_coala.groupby(F.lit(1)).agg(F.collect_list('elements').alias('elements')).collect()
#%% convert the elements to strings; can be skipped for numeric columns
elements = [str(x) for x in tst_array[0]['elements']]
#%% add the values that occur more than once as an array column in the main df
tst_cmp = tst.withColumn("elements_array", F.array([F.lit(x) for x in elements]))
# convert the row values into an array
tst_cmp = tst_cmp.withColumn("main_array", F.array(*tst.columns))
#%% find whether any element in the row occurs more than once in the entire data
tst_result = tst_cmp.withColumn("flag", F.size(F.array_intersect(F.col('main_array'), F.col('elements_array'))))
#%% select the disjoint rows
tst_final = tst_result.where('flag = 0')
    

Result:

+---+---+---+---+----------------+-------------------+----+
|  a|  b|  c|  d|  elements_array|         main_array|flag|
+---+---+---+---+----------------+-------------------+----+
| a6| b6| c6|d27|[b2, c3, a1, d7]|  [a6, b6, c6, d27]|   0|
| a9|b88|c54|d71|[b2, c3, a1, d7]|[a9, b88, c54, d71]|   0|
+---+---+---+---+----------------+-------------------+----+
    

【Comments】:
