【问题标题】:Scala spark, input dataframe, return columns where all values equal to 1Scala spark,输入数据框,返回所有值等于 1 的列
【发布时间】:2020-02-12 13:46:55
【问题描述】:

给定一个数据框,假设它包含 4 列和 3 行。我想编写一个函数来返回该列中所有值都等于 1 的列。

这是一个 Scala 代码。我想使用一些火花转换来转换或过滤数据框输入。这个过滤器应该在一个函数中实现。

case class Grade(c1: Integral, c2: Integral, c3: Integral, c4: Integral)
val example = Seq(
      Grade(1,3,1,1),
      Grade(1,1,null,1),
      Grade(1,10,2,1)
    )

    val dfInput = spark.createDataFrame(example)

在我调用函数 filterColumns() 之后

val dfOutput = dfInput.filterColumns()

它应该返回 3 行 2 列的数据框,值全为 1。

【问题讨论】:

    标签: scala dataframe apache-spark filter apache-spark-sql


    【解决方案1】:

    使用Dataset[Grade]的更易读的方法

    import org.apache.spark.sql.functions.col
    import scala.collection.mutable
    import org.apache.spark.sql.Column
    
    val tmp = dfInput.map(grade => grade.dropWhenNotEqualsTo(1))
    val rowsCount = dfInput.count()
    
    val colsToRetain = mutable.Set[Column]()
     for (column <- tmp.columns) {
       val withoutNullsCount = tmp.select(column).na.drop().count()
       if (rowsCount == withoutNullsCount) colsToRetain += col(column)
    }
    
    dfInput.select(colsToRetain.toArray:_*).show()
    
    +---+---+
    | c4| c1|
    +---+---+
    |  1|  1|
    |  1|  1|
    |  1|  1|
    +---+---+
    

    还有案例对象

    case class Grade(c1: Integer, c2: Integer, c3: Integer, c4: Integer) {
      def dropWhenNotEqualsTo(n: Integer): Grade = {
        Grade(nullOrValue(c1, n), nullOrValue(c2, n), nullOrValue(c3, n), nullOrValue(c4, n))
      }
      def nullOrValue(c: Integer, n: Integer) = if (c == n) c else null
    }
    
    1. grade.dropWhenNotEqualsTo(1) -> 返回一个新的 Grade,其中不满足条件的值被替换为空值
    +---+----+----+---+
    | c1|  c2|  c3| c4|
    +---+----+----+---+
    |  1|null|   1|  1|
    |  1|   1|null|  1|
    |  1|null|null|  1|
    +---+----+----+---+
    
    1. (column &lt;- tmp.columns) -> 遍历列

    2. tmp.select(column).na.drop() -> 删除带有空值的行 例如 c2 这将返回

    +---+
    | c2|
    +---+
    |  1|
    +---+
    
    1. if (rowsCount == withoutNullsCount) colsToRetain += col(column) -> 如果列包含空值,则删除它

    【讨论】:

    • 你能告诉我我需要导入哪些包吗?而且,colsToRetain += col(column) 有一些错误,请您解释一下“col()”和“+=”,它在我的计算机中显示错误。
    • import org.apache.spark.sql.functions.col, import scala.collection.mutable, import org.apache.spark.sql.Column 应该够了
    • 谢谢。地图功能需要编码器吗?那是什么意思?你能改变输出数据框的顺序与输入相同吗?说 c1,c4,而不是 c4, c1
    • 关于编码器 - jaceklaskowski.gitbooks.io/mastering-spark-sql/… 它们可作为来自 import sparkSession.implicits._ 的隐式转换使用以保持顺序使用 List 而不是 Set
    【解决方案2】:

    我会尝试在没有nulls 的情况下准备要处理的数据集。如果列数很少,这种简单的迭代方法可能会很好用(不要忘记在 import spark.implicits._ 之前导入 spark 隐式):

    val example = spark.sparkContext.parallelize(Seq(
        Grade(1,3,1,1),
        Grade(1,1,0,1),
        Grade(1,10,2,1)
    )).toDS().cache()
    
    def allOnes(colName: String, ds: Dataset[Grade]): Boolean = {
        val row = ds.select(colName).distinct().collect()
        if (row.length == 1 && row.head.getInt(0) == 1) true
        else false
    }
    
    val resultColumns = example.columns.filter(col => allOnes(col, example))
    example.selectExpr(resultColumns: _*).show()
    

    结果是:

    +---+---+
    | c1| c4|
    +---+---+
    |  1|  1|
    |  1|  1|
    |  1|  1|
    +---+---+
    

    如果nulls 是不可避免的,请使用无类型数据集(又名数据框):

    val schema = StructType(Seq(
        StructField("c1", IntegerType, nullable = true),
        StructField("c2", IntegerType, nullable = true),
        StructField("c3", IntegerType, nullable = true),
        StructField("c4", IntegerType, nullable = true)
    ))
    
    val example = spark.sparkContext.parallelize(Seq(
        Row(1,3,1,1),
        Row(1,1,null,1),
        Row(1,10,2,1)
    ))
    
    val dfInput = spark.createDataFrame(example, schema).cache()
    
    def allOnes(colName: String, df: DataFrame): Boolean = {
        val row = df.select(colName).distinct().collect()
        if (row.length == 1 && row.head.getInt(0) == 1) true
        else false
    }
    
    val resultColumns= dfInput.columns.filter(col => allOnes(col, dfInput))
    dfInput.selectExpr(resultColumns: _*).show()
    

    【讨论】:

      【解决方案3】:

      其中一个选项是 rdd 上的reduce

        import spark.implicits._
      
        val df= Seq(("1","A","3","4"),("1","2","?","4"),("1","2","3","4")).toDF()
        df.show()
      
        val first = df.first()
        val size = first.length
        val diffStr = "#"
        val targetStr = "1"
      
         def rowToArray(row: Row): Array[String] = {
           val arr = new Array[String](row.length)
           for (i <- 0 to row.length-1){
             arr(i) = row.getString(i)
           }
           arr
         }
      
        def compareArrays(a1: Array[String], a2: Array[String]): Array[String] = {
          val arr = new Array[String](a1.length)
          for (i <- 0 to a1.length-1){
            arr(i) = if (a1(i).equals(a2(i)) && a1(i).equals(targetStr)) a1(i) else diffStr
          }
          arr
        }
      
        val diff = df.rdd
          .map(rowToArray)
          .reduce(compareArrays)
      
        val cols = (df.columns zip diff).filter(!_._2.equals(diffStr)).map(s=>df(s._1))
      
        df.select(cols:_*).show()
      
          +---+---+---+---+
          | _1| _2| _3| _4|
          +---+---+---+---+
          |  1|  A|  3|  4|
          |  1|  2|  ?|  4|
          |  1|  2|  3|  4|
          +---+---+---+---+
      
          +---+
          | _1|
          +---+
          |  1|
          |  1|
          |  1|
          +---+
      

      【讨论】:

      • 你能解释一下最后两行吗?谢谢。
      • zip 创建元组列表,其中 ._1 元素是列名,._2 来自 diff(_1,1) (_2,#) (_3,#) (_4,#),然后我过滤掉带有“#”的记录并在映射步骤中仅返回列名。最后,我使用 :_* 将这些列名放入 select as varargs
      猜你喜欢
      • 1970-01-01
      • 2020-08-25
      • 1970-01-01
      • 1970-01-01
      • 2020-09-08
      • 1970-01-01
      • 2017-08-08
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多