[Question Title]: Scala Dataframe get max value of specific row
[Posted]: 2019-08-29 09:49:02
[Question Description]:

Given a dataframe with an index column ("Z"):

val tmp = Seq(("D", 0.1, 0.3, 0.4), ("E", 0.3, 0.1, 0.4), ("F", 0.2, 0.2, 0.5)).toDF("Z", "a", "b", "c")

+---+---+---+---+
|  Z|  a|  b|  c|
+---+---+---+---+
|  D|0.1|0.3|0.4|
|  E|0.3|0.1|0.4|
|  F|0.2|0.2|0.5|
+---+---+---+---+

Say I'm interested in the first row, where Z = "D":

tmp.filter(col("Z") === "D")

+---+---+---+---+
|  Z|  a|  b|  c|
+---+---+---+---+
|  D|0.1|0.3|0.4|
+---+---+---+---+

How can I get the min and max values of that Dataframe row, together with their corresponding column names, while keeping the index column?

If I want the top 2 max values, the required output is:

+---+---+---+
|  Z|  b|  c|
+---+---+---+
|  D|0.3|0.4|
+---+---+---+

If I want the min, the required output is:

+---+---+
| Z |  a|
+---+---+
| D |0.1|
+---+---+

What I tried:

// first convert that DF to an array
val tmp = df.collect.map(_.toSeq).flatten
// returns
tmp: Array[Any] = Array(0.1, 0.3, 0.4) <--- don't know why Any is returned


// take the top values of the array
val n = 1
tmp.zipWithIndex.sortBy(-_._1).take(n).map(_._2)

But I get the error:

   No implicit Ordering defined for Any.
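The error comes from `Row.toSeq` returning values typed as `Any`, for which Scala defines no `Ordering`. A minimal sketch on plain Scala data (the array below is a stand-in for the collected row values; casting to `Double` is safe here because a, b and c are all Doubles):

```scala
// Stand-in for df.collect.map(_.toSeq).flatten -- values come back as Any.
val collected: Array[Any] = Array(0.1, 0.3, 0.4)

// Casting to Double restores the Ordering that sortBy needs.
val typed: Array[Double] = collected.map(_.asInstanceOf[Double])

val n = 2
// indices of the top-n values, largest first
val topIdx = typed.zipWithIndex.sortBy(-_._1).take(n).map(_._2)
// topIdx: Array(2, 1), i.e. columns c and b
```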

Is there any way to do this directly from the dataframe instead of via an array?

[Question Discussion]:

  • Could you please provide more details on the desired output, and let me know which Dataframe is df and which is tmp?
  • @Nikk updated to reflect the desired output

Tags: scala dataframe apache-spark apache-spark-sql


[Solution 1]:

You can do it like this:

tmp
  .where($"Z" === "D")
  .take(1)
  .map { row =>
      // skip the index column at position 0; a, b, c are Doubles
      Seq(row.getDouble(1), row.getDouble(2), row.getDouble(3))
  }
  .head
  .sortBy(d => -d)
  .take(2)

Alternatively, if you have a large number of fields, you can zip the schema with the row indices and match each row field against its schema data type, like this:

import org.apache.spark.sql.types._

val schemaWithIndex = tmp.schema.zipWithIndex

tmp
  .where($"Z" === "D")
  .take(1)
  .map { row =>
      // keep only the Double-typed fields (this skips the String index column,
      // which would otherwise cause a MatchError)
      schemaWithIndex.collect {
          case (field, index) if field.dataType == DoubleType => row.getDouble(index)
      }
  }
  .head
  .sortBy(d => -d)
  .take(2)

There may well be a simpler way to do this.
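For completeness, here is a pure-Scala sketch of the schema-zip idea that also keeps the column names (the `schema` and `row` values below are hypothetical stand-ins for the Spark schema and a collected row, so it runs without a Spark session):

```scala
// Hypothetical stand-ins: the schema as (name, type) pairs and one collected row.
val schema = Seq(("Z", "string"), ("a", "double"), ("b", "double"), ("c", "double"))
val row: Seq[Any] = Seq("D", 0.1, 0.3, 0.4)

// Pair each field with its index, keep only the numeric ones,
// and carry the column name along with the value.
val doubles = schema.zipWithIndex.collect {
  case ((name, "double"), i) => name -> row(i).asInstanceOf[Double]
}

val top2 = doubles.sortBy(-_._2).take(2)
// top2: List((c,0.4), (b,0.3))
```

Keeping the names around from the start avoids having to map sorted positions back to columns afterwards.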

[Discussion]:

  • Could you give an example for the case where there are more fields?
[Solution 2]:

Definitely not the fastest way, but it works directly on the dataframe.

A more generic solution:

// somewhere in codebase
import spark.implicits._
import org.apache.spark.sql.functions._

def transform[T, R : Encoder](ds: DataFrame, colsToSelect: Seq[String])(func: Map[String, T] => Map[String, R])
                            (implicit encoder: Encoder[Map[String, R]]): DataFrame = {
    ds.map(row => func(row.getValuesMap(colsToSelect)))
      .toDF()
      .select(explode(col("value")))
      .withColumn("idx", lit(1))
      .groupBy(col("idx")).pivot(col("key")).agg(first(col("value")))
      .drop("idx")
  }

Now it comes down to working with a Map, where the map key is the field name and the map value is the field value:

def fuzzyStuff(values: Map[String, Any]): Map[String, String] = {
  val valueForA = values("a").asInstanceOf[Double]
  // Do whatever you want to do
  // ...
  // use a Map as the return type, where the key is a column name
  // and the value is whatever you want
  Map("x" -> s"fuzzyA-$valueForA")
}


def maxN(n: Int)(values: Map[String, Double]): Map[String, Double] = {
  // sort by the value, not by the (key, value) tuple
  // (tuple ordering would compare the key strings first)
  values.toSeq.sortBy(-_._2).take(n).toMap
}

Usage:

val tmp = Seq((0.1,0.3, 0.4), (0.3, 0.1, 0.4), (0.2, 0.2, 0.5)).toDF("a", "b", "c")
val filtered = tmp.filter(col("a") === 0.1)

transform(filtered, colsToSelect = Seq("a", "b", "c"))(maxN(2))
   .show()

+---+---+
|  b|  c|
+---+---+
|0.3|0.4|
+---+---+

transform(filtered, colsToSelect = Seq("a", "b", "c"))(fuzzyStuff)
   .show()

+----------+
|         x|
+----------+
|fuzzyA-0.1|
+----------+

  1. Define the max and min functions

  def maxN(values: Map[String, Double], n: Int): Map[String, Double] = {
    values.toSeq.sortBy(-_._2).take(n).toMap
  }

  def min(values: Map[String, Double]): Map[String, Double] = {
    Map(values.minBy(_._2))
  }
  2. Create the dataset

val tmp = Seq((0.1, 0.3, 0.4), (0.3, 0.1, 0.4), (0.2, 0.2, 0.5)).toDF("a", "b", "c")
val filtered = tmp.filter(col("a") === 0.1)
  3. Explode and pivot the map type
val df = filtered.map(row => maxN(row.getValuesMap(Seq("a", "b", "c")), 2)).toDF()

val exploded = df.select(explode($"value"))
+---+-----+
|key|value|
+---+-----+
|  b|  0.3|
|  c|  0.4|
+---+-----+

//Then pivot
exploded.withColumn("idx", lit(1))
      .groupBy($"idx").pivot($"key").agg(first($"value"))
      .drop("idx")
      .show()

+---+---+
|  b|  c|
+---+---+
|0.3|0.4|
+---+---+
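As the discussion below points out, mixing in negative values is a good stress test for these helpers. A quick pure-Scala sanity check of the value-based sorting (no Spark session needed); the hypothetical `minN` also generalizes min to the smallest n, as asked in the discussion:

```scala
// Sort by the numeric value; tuple ordering on (key, value) pairs
// would compare the key strings first and give wrong answers.
def maxN(n: Int)(values: Map[String, Double]): Map[String, Double] =
  values.toSeq.sortBy(-_._2).take(n).toMap

def minN(n: Int)(values: Map[String, Double]): Map[String, Double] =
  values.toSeq.sortBy(_._2).take(n).toMap

val vals = Map("a" -> -0.5, "b" -> 0.3, "c" -> 0.4)
// maxN(2)(vals) keeps b and c; minN(1)(vals) keeps a,
// even though |-0.5| is the largest magnitude.
```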

[Discussion]:

  • Hi, I updated my question to include an index column, could you update this answer?
  • Also, how would the min function be updated if I wanted the min 2 instead of the min?
  • I realized this example doesn't work if there are negative values, since it only looks at the magnitude.