【问题标题】:How to efficiently select dataframe columns containing a certain value in Spark?如何在 Spark 中有效地选择包含某个值的数据框列?
【发布时间】:2020-11-30 17:12:04
【问题描述】:

假设您在 spark(字符串类型)中有一个数据框,并且您想要删除任何包含“foo”的列。在下面的示例数据框中,您将删除列“c2”和“c3”,但保留“c1”。但是,我希望该解决方案能够推广到大量的列和行。

    +-------------------+
    |   c1|   c2|     c3|
    +-------------------+
    | this|  foo|  hello|
    | that|  bar|  world|
    |other|  baz| foobar|
    +-------------------+

我的解决方案是扫描数据框中的每一列,然后使用数据框 API 和内置函数聚合结果。 因此,可以像这样扫描每一列(我是 scala 新手,请原谅语法错误):

df = df.select(df.columns.map(c => col(c).like("foo"))

从逻辑上讲,我会有一个像这样的中间数据框:

    +--------------------+
    |    c1|    c2|    c3|
    +--------------------+
    | false|  true| false|
    | false| false| false|
    | false| false|  true|
    +--------------------+

然后将其聚合成一行以读取需要删除的列。

exprs = df.columns.map( c => max(c).alias(c))

drop = df.agg(exprs.head, exprs.tail: _*)

    +--------------------+
    |    c1|    c2|    c3|
    +--------------------+
    | false|  true|  true|
    +--------------------+

现在可以删除任何包含 true 的列。

我的问题是:有没有更好的方法来做到这一点,性能明智?在这种情况下,一旦找到“foo”,spark 是否会停止扫描列?数据的存储方式是否重要(镶木地板有帮助吗?)。

谢谢,我是新来的,所以请告诉我如何改进这个问题。

【问题讨论】:

  • 除了不实际工作,我不知道你怎么能把它短路。有兴趣看看对方怎么说。

标签: scala apache-spark apache-spark-sql


【解决方案1】:

根据您的数据,例如,如果您有很多 foo 值,则下面的代码可能会更有效地执行:

val colsToDrop = df.columns.filter{ c =>
  !df.where(col(c).like("foo")).limit(1).isEmpty
}

df.drop(colsToDrop: _*)

更新:删除了多余的.limit(1)

val colsToDrop = df.columns.filter{ c =>
  !df.where(col(c).like("foo")).isEmpty
}

df.drop(colsToDrop: _*)

【讨论】:

  • 限制结果集还是搜索?
  • 我认为.limit(1) 是不必要的。
  • @LeoC 为什么会这样?您认为限制是如何起作用的?
  • myDF.limit(1) 如果myDF 非空,则为 1 行 DataFrame,否则为 0 行。因此myDF.limit(1).isEmpty 对验证myDF 是否为空有效。对我来说似乎是多余的。
  • 这是一个很好的观点@LeoC。这是 isEmpty 的一个实现:def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => plan.executeCollect().head.getLong(0) == 0 }
【解决方案2】:

遵循您的逻辑的答案(正确计算),但我认为另一个答案更好,对于后代和您提高 Scala 的能力更是如此。我不确定另一个答案实际上是否有效,但这也不是。不确定镶木地板是否有帮助,很难衡量。

另一种选择是在驱动程序上编写一个循环并访问每个 由于柱状、统计数据和 向下推。

import org.apache.spark.sql.functions._
def myUDF = udf((cols: Seq[String], cmp: String) => cols.map(code => if (code == cmp) true else false ))

val df = sc.parallelize(Seq(
   ("foo", "abc", "sss"),
   ("bar", "fff", "sss"),
   ("foo", "foo", "ddd"),
   ("bar", "ddd", "ddd")
   )).toDF("a", "b", "c")

val res = df.select($"*", array(df.columns.map(col): _*).as("colN"))
            .withColumn( "colres", myUDF( col("colN") , lit("foo") )  )

res.show()
res.printSchema()
val n = 3
val res2 = res.select( (0 until n).map(i => col("colres")(i).alias(s"c${i+1}")): _*)
res2.show(false)

val exprs = res2.columns.map( c => max(c).alias(c))
val drop = res2.agg(exprs.head, exprs.tail: _*)
drop.show(false)

【讨论】:

    猜你喜欢
    • 2021-11-27
    • 1970-01-01
    • 2016-11-06
    • 1970-01-01
    • 1970-01-01
    • 2021-08-20
    • 2021-10-03
    • 1970-01-01
    相关资源
    最近更新 更多