【Question Title】: Splitting rows of a dataset depending on a column's values
【Posted】: 2021-04-14 05:02:37
【Question】:

I am using Spark 3.1.1 with Java 8. I am trying to split a Dataset<Row> based on the value of one of its numeric columns (above or below a threshold), but rows may only be split apart when certain string column values are identical. I am trying something like this:

Iterator<Row> iter2 = partition.toLocalIterator();
while (iter2.hasNext()) {
    Row item = iter2.next();
    // getColVal is a function that gets the value of a given column
    String numValue = getColVal(item, dim);
    if (Integer.parseInt(numValue) < threshold)
        pl.add(item);   // rows below the threshold
    else
        pr.add(item);   // rows at or above the threshold
}

But how can I check, before splitting, that the other (string) column values of the rows involved are identical, so that the split can be performed?

PS: I tried grouping by the columns before splitting:

Dataset<Row> newDataset=oldDataset.groupBy("col1","col4").agg(col("col1"));

but it does not work.

Thanks for your help.

EDIT:

A sample of the dataset I want to split:

abc,9,40,A
abc,7,50,A
cde,4,20,B
cde,3,25,B

If the threshold is 30, the first two rows and the last two rows would form two datasets, since within each pair the first and fourth columns are identical; otherwise the split cannot be performed.

EDIT: the resulting output would be

    abc,9,40,A
    abc,7,50,A


    cde,4,20,B
    cde,3,25,B

【Comments】:

  • If I understand correctly, you want to get 2 datasets when col3 exceeds the threshold of 30? Could you also add what the resulting datasets would look like? And could you provide a sample output for the case where the threshold is not met?
  • Exactly: 2 datasets if col3 exceeds the threshold of 30, and the original dataset is returned if the threshold is not met.
  • See my edit for the sample output.

Tags: apache-spark


【Solution 1】:

I mostly work in pyspark, but you can adapt this to your environment.

## could add some conditional logic or just always output 2 data frames where
##   one would be empty

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

## get or create the active session
spark = SparkSession.builder.getOrCreate()

print("pdf - two dataframe")
## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[40,50,20,25],'col4':['A','A','B','B']})
print( pdf )

## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf) 

sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc|   9|  40|   A|
# | abc|   7|  50|   A|
# | cde|   4|  20|   B|
# | cde|   3|  25|   B|
# +----+----+----+----+




## filter
pl = sdf.filter('col3 <= 30')\
        .groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))
pr = sdf.filter('col3 > 30')\
        .groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))
print("pl")
pl.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | cde|   B|    7|
# +----+----+-----+


print("pr")
pr.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | abc|   A|   16|
# +----+----+-----+


print("pdf - one dataframe")
## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[11,29,20,25],'col4':['A','A','B','B']})
print( pdf )

## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf) 
sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc|   9|  11|   A|
# | abc|   7|  29|   A|
# | cde|   4|  20|   B|
# | cde|   3|  25|   B|
# +----+----+----+----+



pl = sdf.filter('col3 <= 30')\
        .groupBy("col1","col4").agg( F.sum('col2').alias('sumC2') )
pr = sdf.filter('col3 > 30')\
        .groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))

print("pl")
pl.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# | abc|   A|   16|
# | cde|   B|    7|
# +----+----+-----+

print("pr")
pr.show()
# +----+----+-----+
# |col1|col4|sumC2|
# +----+----+-----+
# +----+----+-----+
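
Since the question is in Java, here is a minimal sketch of the same threshold split using Spark's Java API; the class name, local master, and sample.csv path are placeholders, not from the original post:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.sum;

public class SplitByThreshold {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("SplitByThreshold")
                .master("local[*]")
                .getOrCreate();

        // sample.csv is a hypothetical file holding the rows from the question
        Dataset<Row> sdf = spark.read()
                .option("inferSchema", "true")
                .csv("sample.csv")
                .toDF("col1", "col2", "col3", "col4");

        // Same idea as the pyspark version: filter on the threshold first,
        // then group by the string columns that must be identical.
        Dataset<Row> pl = sdf.filter(col("col3").leq(30))
                .groupBy("col1", "col4")
                .agg(sum("col2").alias("sumC2"));
        Dataset<Row> pr = sdf.filter(col("col3").gt(30))
                .groupBy("col1", "col4")
                .agg(sum("col2").alias("sumC2"));

        pl.show();
        pr.show();
        spark.stop();
    }
}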

Filtering by a dynamic mean

print("pdf - filter by mean")
## create pandas dataframe
pdf = pd.DataFrame({'col1':['abc','abc','cde','cde'],'col2':[9,7,4,3],'col3':[40,50,20,25],'col4':['A','A','B','B']})
print( pdf )

## move it to spark
print("sdf")
sdf = spark.createDataFrame(pdf) 
sdf.show()
# +----+----+----+----+
# |col1|col2|col3|col4|
# +----+----+----+----+
# | abc|   9|  40|   A|
# | abc|   7|  50|   A|
# | cde|   4|  20|   B|
# | cde|   3|  25|   B|
# +----+----+----+----+

## note: the ordered window makes this a running mean within each col1 group;
## drop orderBy for the plain per-group mean
w = Window.partitionBy("col1").orderBy("col2")
## add another column, the mean of col2 partitioned by col1
sdf = sdf.withColumn('mean_c2', F.mean('col2').over(w))

## filter by the dynamic mean
pr = sdf.filter('col2 > mean_c2')
pr.show()

# +----+----+----+----+-------+
# |col1|col2|col3|col4|mean_c2|
# +----+----+----+----+-------+
# | cde|   4|  20|   B|    3.5|
# | abc|   9|  40|   A|    8.0|
# +----+----+----+----+-------+
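
The dynamic-mean filter translates the same way. A sketch that assumes the sdf and imports from the Java example above:

import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;

import static org.apache.spark.sql.functions.mean;

// Ordered window, so mean("col2") is a running mean within each col1 group,
// exactly as in the pyspark snippet above.
WindowSpec w = Window.partitionBy("col1").orderBy("col2");
Dataset<Row> withMean = sdf.withColumn("mean_c2", mean("col2").over(w));

// keep only the rows above their group's (running) mean
Dataset<Row> pr = withMean.filter(col("col2").gt(col("mean_c2")));
pr.show();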

【Discussion】:

  • If possible, could you explain in Java what these lines do: `pl = sdf.filter('col3 <= 30').groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))` and `pr = sdf.filter('col3 > 30').groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))`?
  • Sure. `pl = sdf.filter('col3 <= 30').groupBy("col1","col4").agg(F.sum('col2').alias('sumC2'))` sets `pl` to the rows where col3 is at most 30, grouped by columns col1 and col4, with col2 summed and aliased to sumC2. If you just want to split into 2 dataframes, all you need is `pl = sdf.filter('col3 <= 30')` and `pr = sdf.filter('col3 > 30')`; `.agg` is for when further aggregation is needed after filtering.
  • Found this: codota.com/code/java/methods/org.apache.spark.sql.Dataset/… `JavaDataFrameSuite.testExecution()`: `@Test public void testExecution() { Dataset<Row> df = spark.table("testData").filter("key = 1"); Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); }`
  • OK, so select would act as a filter, something like `.select("col3 > threshold")`. What would `collectAsList().get(0).get(0)` be used for?
  • I am trying something like this: `Dataset<Row> filteredDF = dF.filter("col2 > 60").groupBy("col0","col1","col3").agg();` but I don't know what to put in `agg()` (see the sketch below).
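
On that last comment: in the Java API, agg() on a grouped Dataset needs at least one aggregate Column. A minimal sketch of one plausible way to complete the line, reusing the commenter's dF and column names (summing col2 is only an example):

import static org.apache.spark.sql.functions.sum;

// agg() takes one or more aggregate Columns; pick whichever aggregate you need
Dataset<Row> filteredDF = dF.filter("col2 > 60")
        .groupBy("col0", "col1", "col3")
        .agg(sum("col2").alias("sumC2"));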