【问题标题】:filter spark dataframe with row field that is an array of strings使用作为字符串数组的行字段过滤火花数据帧
【发布时间】:2016-04-22 09:07:18
【问题描述】:

使用 Spark 1.5 和 Scala 2.10.6

我正在尝试通过作为字符串数组的字段“标签”过滤数据框。查找所有带有“private”标签的行。

val report = df.select("*")
  .where(df("tags").contains("private"))

得到:

线程“主”org.apache.spark.sql.AnalysisException 中的异常: 由于数据类型不匹配,无法解析“包含(标签,私有)”: 参数 1 需要字符串类型,但是,'tags' 是数组 类型。;

过滤方法是否更适合?

更新:

数据来自 cassandra 适配器,但显示我正在尝试做的事情并得到上述错误的最小示例是:

  def testData (sc: SparkContext): DataFrame = {
    val stringRDD = sc.parallelize(Seq("""
      { "name": "ed",
        "tags": ["red", "private"]
      }""",
      """{ "name": "fred",
        "tags": ["public", "blue"]
      }""")
    )
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._
    sqlContext.read.json(stringRDD)
  }
  def run(sc: SparkContext) {
    val df1 = testData(sc)
    df1.show()
    val report = df1.select("*")
      .where(df1("tags").contains("private"))
    report.show()
  }

更新:标签数组可以是任意长度,“私有”标签可以在任意位置

更新:一种有效的解决方案:UDF

val filterPriv = udf {(tags: mutable.WrappedArray[String]) => tags.contains("private")}
val report = df1.filter(filterPriv(df1("tags")))

【问题讨论】:

  • 发布您的数据样本以及您如何创建 df
  • 一种选择是构建 UDF。
  • 好吧,在查看源代码之后(因为Column.contains 的scaladoc 只说“包含其他元素”,这不是很有启发性),我看到Column.contains 构造了一个@ 的实例987654326@ 表示“如果字符串 left 包含字符串 right 则返回 true 的函数”。因此,在这种情况下,df1("tags").contains 似乎无法做我们希望它做的事情。我不知道有什么替代建议。在...expressions 中也有一个ArrayContains,但Column 似乎没有使用它。
  • 确实,将数据改为字符串而不是字符串数组后,发现查询成功了。
  • @DavidMaust,我有一个 UDF 可以工作:val filterPriv = udf {(tags: mutable.WrappedArray[String]) => tags.contains("private")}; val report = df1.filter(filterPriv(df1("tags"))) 仍在寻找更好的东西,但至少我没有被阻止。谢谢!

标签: scala apache-spark


【解决方案1】:

我认为如果您使用where(array_contains(...)),它会起作用。这是我的结果:

scala> import org.apache.spark.SparkContext
import org.apache.spark.SparkContext

scala> import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.DataFrame

scala> def testData (sc: SparkContext): DataFrame = {
     |     val stringRDD = sc.parallelize(Seq
     |      ("""{ "name": "ned", "tags": ["blue", "big", "private"] }""",
     |       """{ "name": "albert", "tags": ["private", "lumpy"] }""",
     |       """{ "name": "zed", "tags": ["big", "private", "square"] }""",
     |       """{ "name": "jed", "tags": ["green", "small", "round"] }""",
     |       """{ "name": "ed", "tags": ["red", "private"] }""",
     |       """{ "name": "fred", "tags": ["public", "blue"] }"""))
     |     val sqlContext = new org.apache.spark.sql.SQLContext(sc)
     |     import sqlContext.implicits._
     |     sqlContext.read.json(stringRDD)
     |   }
testData: (sc: org.apache.spark.SparkContext)org.apache.spark.sql.DataFrame

scala>   
     | val df = testData (sc)
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> val report = df.select ("*").where (array_contains (df("tags"), "private"))
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> report.show
+------+--------------------+
|  name|                tags|
+------+--------------------+
|   ned|[blue, big, private]|
|albert|    [private, lumpy]|
|   zed|[big, private, sq...|
|    ed|      [red, private]|
+------+--------------------+

请注意,如果你写where(array_contains(df("tags"), "private")),它可以工作,但如果你写where(df("tags").array_contains("private"))(更直接地类似于你最初写的)它会失败,array_contains is not a member of org.apache.spark.sql.Column。查看Column 的源代码,我看到有一些东西要处理contains(为此构造一个Contains 实例)但不是array_contains。也许这是一个疏忽。

【讨论】:

  • .select("*") 不需要 => df.where(...) ...
  • 需要import org.apache.spark.sql.functions.array_contains才能使用此方法。
【解决方案2】:

您可以使用 ordinal 来引用 json 数组,例如在你的情况下df("tags")(0)。这是一个工作示例

scala> val stringRDD = sc.parallelize(Seq("""
     |       { "name": "ed",
     |         "tags": ["private"]
     |       }""",
     |       """{ "name": "fred",
     |         "tags": ["public"]
     |       }""")
     |     )
stringRDD: org.apache.spark.rdd.RDD[String] = ParallelCollectionRDD[87] at parallelize at <console>:22

scala> import sqlContext.implicits._
import sqlContext.implicits._

scala> sqlContext.read.json(stringRDD)
res28: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> val df=sqlContext.read.json(stringRDD)
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> df.columns
res29: Array[String] = Array(name, tags)

scala> df.dtypes
res30: Array[(String, String)] = Array((name,StringType), (tags,ArrayType(StringType,true)))

scala> val report = df.select("*").where(df("tags")(0).contains("private"))
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]

scala> report.show
+----+-------------+
|name|         tags|
+----+-------------+
|  ed|List(private)|
+----+-------------+

【讨论】:

  • 谢谢。如果 pos 是固定的,但它不是。我应该让测试数据复杂一点,数组中可以有任意数量的标签,位置是任意的。
  • @navicore 那么你的代码应该可以工作。检查我的更新
  • 有趣,我错过了一些东西,看起来就像我正在做的事情,但得到了错误。现在仔细检查 spark 版本...
  • @navicore 这是 1.5.4
  • 谢谢。我一定在某处交叉手。我尝试了 1.5.1 和 1.6,val report = df.select("*").where(df("tags").contains("private")) 在原始帖子中给了我这个错误。挖……
猜你喜欢
  • 2019-01-14
  • 2018-02-18
  • 1970-01-01
  • 2020-08-09
  • 2018-10-27
  • 2016-05-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多