【Question Title】: How to get other columns when using Spark DataFrame groupBy?
【Posted】: 2025-11-23 09:05:01
【Question】:

When I use DataFrame groupBy like this:

df.groupBy(df("age")).agg(Map("id"->"count"))

I only get a DataFrame containing the "age" and "count(id)" columns, but df has many other columns, such as "name".

In short, I want the result of this MySQL query:

"select name, age, count(id) from df group by age"

How should I do this with groupBy in Spark?

【Question Comments】:

  • Why not use "select name, age, count(id) from df group by age, name"? Grouping only by "age" would match many different names but display only one of them.
  • In my question I was just giving a simple example. Using "group by age, name" obviously produces a different result than "group by age"...
  • An important thing to consider is: "When I group by one attribute and need another column, which of that column's values do I use?" Unless we specify an aggregation technique for the other column, the computer cannot know which of the possibly many values in it to pick.
  • "select name, age, count(id) from df group by age" does not look like a valid SQL statement. You cannot select name if it is not part of the group by clause.

Tags: sql apache-spark dataframe apache-spark-sql


【Solution 1】:

Long story short, in general you have to join the aggregated results back to the original table. Spark SQL follows the same pre-SQL:1999 convention as most of the major databases (PostgreSQL, Oracle, MS SQL Server), which does not allow additional columns in aggregation queries.

Since for aggregations like count the result for the extra columns is not well defined, and the behavior tends to vary among the systems that do support this kind of query, you can simply include the additional columns using an arbitrary aggregate such as first or last.
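
For example, a minimal sketch of that approach, assuming the df, age, id and name columns from the question (which name you get per group is arbitrary):

import org.apache.spark.sql.functions.{count, first}

// Keep one arbitrary name per age group next to the count.
val result = df.groupBy("age").agg(count("id").as("count_id"), first("name").as("name"))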

In some cases you can replace agg with a select using window functions and a subsequent where, but depending on the context it can be quite expensive.
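
And a rough sketch of the window-function variant, again assuming the column names from the question:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count}

// Attach the per-age count to every row instead of collapsing the rows.
val byAge = Window.partitionBy("age")
val withCount = df.select(col("name"), col("age"), count(col("id")).over(byAge).as("count_id"))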

【Comments】:

  • What about groupBy("columnName").agg(collect_list(struct("*")).as("newGroupedListColumn"))?
  • So the short answer is to use first/last aggregates: df.groupBy(df("age")).agg(first("name"), first("some other col to include"), ...)
  • I used the comment above and it worked for me. I added a detailed example as an answer.
【Solution 2】:

One way to get back all the columns after performing a groupBy is to use a join.

feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)

data_joined will now contain all of the columns, including the count values.
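
For reference, a rough Scala sketch of the same join-back idea, assuming the name and age columns from the question:

// Aggregate first, then join the counts back so every original column is kept.
val dataCounts = df.groupBy("name", "age").count()
val dataJoined = df.join(dataCounts, Seq("name", "age"))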

【Comments】:

  • But the rows get duplicated.
  • Consider this: df.join(data_counts, feature_group).dropDuplicates()
【Solution 3】:

This solution might help.

from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window

# The snippet assumes an existing sqlContext; create one here so the example is self-contained.
sc = SparkContext.getOrCreate(SparkConf())
sqlContext = SQLContext(sc)

name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
             (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
             (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]

# A window partitioned by age: counting "id" over it attaches the group size to every row.
age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])

name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()

Output:

+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26|    1|
|113| tuv| 29|    1|
|110| tuv| 27|    1|
|106| klm| 19|    1|
|103| efg| 22|    1|
|104| ghi| 21|    1|
|105| ijk| 20|    1|
|112| rst| 28|    1|
|101| abc| 24|    2|
|102| cde| 24|    2|
|107| mno| 18|    3|
|111| pqr| 18|    3|
|108| pqr| 18|    3|
+---+----+---+-----+

【Comments】:

【Solution 4】:

#solved #working solution

This solution was put together with the help of @Azmisov's comment in this thread. The code sample is taken from https://sparkbyexamples.com/spark/using-groupby-on-dataframe/

Problem: in Spark Scala, when using groupBy and max on a DataFrame, the returned DataFrame contains only the columns used in the groupBy and max. How do you get all the columns as well? In other words, how do you get the non-grouped columns?

Solution: please go through the complete example below to get all the columns along with groupBy and max.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._ //{col, lit, when, to_timestamp}
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.Column

    val spark = SparkSession
      .builder()
      .appName("app-name")
      .master("local[*]")
      .getOrCreate()

    spark.sparkContext.setLogLevel("ERROR")
    import spark.implicits._

    val simpleData = Seq(("James","Sales","NY",90000,34,10000),
      ("Michael","Sales","NY",86000,56,20000),
      ("Robert","Sales","CA",81000,30,23000),
      ("Maria","Finance","CA",90000,24,23000),
      ("Raman","Finance","CA",99000,40,24000),
      ("Scott","Finance","NY",83000,36,19000),
      ("Jen","Finance","NY",79000,53,15000),
      ("Jeff","Marketing","CA",80000,25,18000),
      ("Kumar","Marketing","NY",91000,50,21000)
    )
    val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
    df.show()

This produces the following output when df is generated.

    output :
    +-------------+----------+-----+------+---+-----+
    |employee_name|department|state|salary|age|bonus|
    +-------------+----------+-----+------+---+-----+
    |        James|     Sales|   NY| 90000| 34|10000|
    |      Michael|     Sales|   NY| 86000| 56|20000|
    |       Robert|     Sales|   CA| 81000| 30|23000|
    |        Maria|   Finance|   CA| 90000| 24|23000|
    |        Raman|   Finance|   CA| 99000| 40|24000|
    |        Scott|   Finance|   NY| 83000| 36|19000|
    |          Jen|   Finance|   NY| 79000| 53|15000|
    |         Jeff| Marketing|   CA| 80000| 25|18000|
    |        Kumar| Marketing|   NY| 91000| 50|21000|
    +-------------+----------+-----+------+---+-----+
    

The code below gives output with awkward column names, but it is still usable:

    val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
    dfwithmax.show()
    
    +----------+-----------+---------------------------+-------------------+-----------------+-------------------+
    |department|max(salary)|first(employee_name, false)|first(state, false)|first(age, false)|first(bonus, false)|
    +----------+-----------+---------------------------+-------------------+-----------------+-------------------+
    |     Sales|      90000|                      James|                 NY|               34|              10000|
    |   Finance|      99000|                      Maria|                 CA|               24|              23000|
    | Marketing|      91000|                       Jeff|                 CA|               25|              18000|
    +----------+-----------+---------------------------+-------------------+-----------------+-------------------+
    

To get proper column names, you can use aliases as shown below:

    val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
    dfwithmax1.show()
    
    output:
    +----------+------+-------------+-----+---+-----+
    |department|salary|employee_name|state|age|bonus|
    +----------+------+-------------+-----+---+-----+
    |     Sales| 90000|        James|   NY| 34|10000|
    |   Finance| 99000|        Maria|   CA| 24|23000|
    | Marketing| 91000|         Jeff|   CA| 25|18000|
    +----------+------+-------------+-----+---+-----+
    

If you still want to change the order of the DataFrame columns, it can be done as follows:

    val reOrderedColumnName : Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
    val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
    orderedDf.show()
    

Complete code:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.Column
    
    object test {
        def main(args: Array[String]): Unit = {
            /** spark session object */
            val spark = SparkSession.builder().appName("app-name").master("local[*]")
                      .getOrCreate()
                      
            spark.sparkContext.setLogLevel("ERROR")
            import spark.implicits._
            
            val simpleData = Seq(("James","Sales","NY",90000,34,10000),
              ("Michael","Sales","NY",86000,56,20000),
              ("Robert","Sales","CA",81000,30,23000),
              ("Maria","Finance","CA",90000,24,23000),
              ("Raman","Finance","CA",99000,40,24000),
              ("Scott","Finance","NY",83000,36,19000),
              ("Jen","Finance","NY",79000,53,15000),
              ("Jeff","Marketing","CA",80000,25,18000),
              ("Kumar","Marketing","NY",91000,50,21000)
            )
            val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
            df.show()
    
            val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
            dfwithmax.show()
            
            val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
            dfwithmax1.show()
            
            val reOrderedColumnName : Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
            val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
            orderedDf.show()
        }
    }
    
    full output :
    +-------------+----------+-----+------+---+-----+
    |employee_name|department|state|salary|age|bonus|
    +-------------+----------+-----+------+---+-----+
    |        James|     Sales|   NY| 90000| 34|10000|
    |      Michael|     Sales|   NY| 86000| 56|20000|
    |       Robert|     Sales|   CA| 81000| 30|23000|
    |        Maria|   Finance|   CA| 90000| 24|23000|
    |        Raman|   Finance|   CA| 99000| 40|24000|
    |        Scott|   Finance|   NY| 83000| 36|19000|
    |          Jen|   Finance|   NY| 79000| 53|15000|
    |         Jeff| Marketing|   CA| 80000| 25|18000|
    |        Kumar| Marketing|   NY| 91000| 50|21000|
    +-------------+----------+-----+------+---+-----+
    
    +----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
    |department|max(salary)|first(employee_name, false)|first(department, false)|first(state, false)|first(age, false)|first(bonus, false)|
    +----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
    |     Sales|      90000|                      James|                   Sales|                 NY|               34|              10000|
    |   Finance|      99000|                      Maria|                 Finance|                 CA|               24|              23000|
    | Marketing|      91000|                       Jeff|               Marketing|                 CA|               25|              18000|
    +----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
    
    +----------+------+-------------+----------+-----+---+-----+
    |department|salary|employee_name|department|state|age|bonus|
    +----------+------+-------------+----------+-----+---+-----+
    |     Sales| 90000|        James|     Sales|   NY| 34|10000|
    |   Finance| 99000|        Maria|   Finance|   CA| 24|23000|
    | Marketing| 91000|         Jeff| Marketing|   CA| 25|18000|
    +----------+------+-------------+----------+-----+---+-----+
    

Exception:

    Exception in thread "main" org.apache.spark.sql.AnalysisException: Reference 'department' is ambiguous, could be: department, department.;
    

This means you have the department column twice: it is used in the groupBy or max, and you also referred to it as "department" in first("department").

For example (see the line below):

    val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("department") as "department", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
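
One simple way to avoid the ambiguity, sketched here as a suggestion rather than part of the original answer: since groupBy("department") already keeps that column, either drop the extra first("department") entirely, or give the aggregated copy a non-conflicting alias.

    // Option 1: rely on the groupBy column itself (as dfwithmax1 above already does).
    val noDuplicate = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name")

    // Option 2: keep the aggregated copy, but under a different (hypothetical) alias.
    val renamedCopy = df.groupBy("department").agg(max("salary") as "salary", first("department") as "department_first")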
    

Thanks! Please upvote if this helped.

【Comments】:

  • For the Finance group you would expect the name Raman (who has the highest salary), but instead you get the mismatched name Maria.
  • This is a very detailed solution!
【Solution 5】:

An aggregate function reduces the values of the specified columns within each group. If you wish to retain other row values, you need to implement a reduction logic that specifies which row each value comes from, for instance keeping all the values of the row with the maximum age. To that end you can use a UDAF (user-defined aggregate function) to reduce the rows within the group.

    import org.apache.spark.sql._
    import org.apache.spark.sql.functions._
    
    
    object AggregateKeepingRowJob {
    
      def main (args: Array[String]): Unit = {
    
        val sparkSession = SparkSession
          .builder()
          .appName(this.getClass.getName.replace("$", ""))
          .master("local")
          .getOrCreate()
    
        val sc = sparkSession.sparkContext
        sc.setLogLevel("ERROR")
    
        import sparkSession.sqlContext.implicits._
    
        val rawDf = Seq(
          (1L, "Moe",  "Slap",  2.0, 18),
          (2L, "Larry",  "Spank",  3.0, 15),
          (3L, "Curly",  "Twist", 5.0, 15),
          (4L, "Laurel", "Whimper", 3.0, 15),
          (5L, "Hardy", "Laugh", 6.0, 15),
          (6L, "Charley",  "Ignore",   5.0, 5)
        ).toDF("id", "name", "requisite", "money", "age")
    
        rawDf.show(false)
        rawDf.printSchema
    
        val maxAgeUdaf = new KeepRowWithMaxAge
    
        val aggDf = rawDf
          .groupBy("age")
          .agg(
            count("id"),
            max(col("money")),
            maxAgeUdaf(
              col("id"),
              col("name"),
              col("requisite"),
              col("money"),
              col("age")).as("KeepRowWithMaxAge")
          )
    
        aggDf.printSchema
        aggDf.show(false)
    
      }
    
    
    }
    

    UDAF:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    // Keeps the entire row with the maximum "money" within each group. The schema matches
    // the (id, name, requisite, money, age) columns the UDAF is applied to above.
    class KeepRowWithMaxAge extends UserDefinedAggregateFunction {

      // These are the input fields for your aggregate function.
      override def inputSchema: StructType =
        StructType(
          StructField("id", LongType) ::
          StructField("name", StringType) ::
          StructField("requisite", StringType) ::
          StructField("money", DoubleType) ::
          StructField("age", IntegerType) :: Nil
        )

      // These are the internal fields you keep for computing your aggregate.
      override def bufferSchema: StructType = inputSchema

      // This is the output type of your aggregation function.
      override def dataType: DataType = inputSchema

      override def deterministic: Boolean = true

      // This is the initial value for your buffer schema.
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = ""
        buffer(2) = ""
        buffer(3) = Double.MinValue
        buffer(4) = 0
      }

      // This is how to update your buffer schema given an input:
      // keep the incoming row if its "money" is larger than the buffered one.
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (input.getAs[Double](3) > buffer.getAs[Double](3)) {
          buffer(0) = input.getAs[Long](0)
          buffer(1) = input.getAs[String](1)
          buffer(2) = input.getAs[String](2)
          buffer(3) = input.getAs[Double](3)
          buffer(4) = input.getAs[Int](4)
        }
      }

      // This is how to merge two objects with the bufferSchema type:
      // keep whichever partial result holds the larger "money".
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        if (buffer2.getAs[Double](3) > buffer1.getAs[Double](3)) {
          buffer1(0) = buffer2.getAs[Long](0)
          buffer1(1) = buffer2.getAs[String](1)
          buffer1(2) = buffer2.getAs[String](2)
          buffer1(3) = buffer2.getAs[Double](3)
          buffer1(4) = buffer2.getAs[Int](4)
        }
      }

      // This is where you output the final value, given the final value of your bufferSchema.
      override def evaluate(buffer: Row): Any = buffer
    }
    

【Comments】:

【Solution 6】:

Here is an example I came across: spark-workshop

      val populationDF = spark.read
                      .option("inferSchema", "true")
                      .option("header", "true")
                      .format("csv").load("file:///databricks/driver/population.csv")
                      .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))
      

      val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))

To get the other columns, I did a simple join between the original DF and the aggregated DF.

      populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()
      

【Comments】:

  • The link points to content in a non-English language.
【Solution 7】:

This pyspark code selects, for each group, the B value of the max([A, B]) combination (if several maxima exist in a group, one of them is picked at random).

In your case A would be age and B any column you did not group by but still want to select.

      from pyspark.sql import SparkSession
      from pyspark.sql import functions as F

      # Create (or reuse) a SparkSession so the snippet is self-contained.
      spark = SparkSession.builder.getOrCreate()

      df = spark.createDataFrame([
          [1, 1, 0.2],
          [1, 1, 0.9],
          [1, 2, 0.6],
          [1, 2, 0.5],
          [1, 2, 0.6],
          [2, 1, 0.2],
          [2, 2, 0.1],
      ], ["group", "A", "B"])

      out = (
          df
          .withColumn("AB", F.struct("A", "B"))
          .groupby("group")
          # F.max(AB) selects AB-combinations with max `A`. If more
          # than one combination remains the one with max `B` is selected. If
          # after this identical combinations remain, a single one of them is picked
          # randomly.
          .agg(F.max("AB").alias("max_AB"))
          .select("group", F.expr("max_AB.B"))
      )
      out.show()
      

Output

      +-----+---+
      |group|  B|
      +-----+---+
      |    1|0.6|
      |    2|0.1|
      +-----+---+
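
For readers on the Scala API, a rough equivalent of the same struct-max trick (assuming the same group, A and B columns):

      import org.apache.spark.sql.functions.{col, max, struct}

      // max over a struct orders by its first field (A), then the second (B).
      val out = df
        .withColumn("AB", struct(col("A"), col("B")))
        .groupBy("group")
        .agg(max("AB").alias("max_AB"))
        .select(col("group"), col("max_AB.B").alias("B"))
      out.show()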
      

【Comments】:

【Solution 8】:

You need to remember that aggregate functions reduce the rows, so you have to specify, via a reduction function, which row's values you want to keep. If you want to retain all rows of a group (warning! this can cause explosions or skewed partitions), you can collect them as a list. You can then use a UDF (user-defined function) to reduce them according to your criteria, in my example money, and then expand the columns from the single reduced row with another UDF. For the purpose of this answer I assume you wish to retain the name of the person with the most money.

        import org.apache.spark.sql._
        import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
        import org.apache.spark.sql.functions._
        import org.apache.spark.sql.types.StringType

        import scala.collection.mutable

        object TestJob3 {

          def main(args: Array[String]): Unit = {

            val sparkSession = SparkSession
              .builder()
              .appName(this.getClass.getName.replace("$", ""))
              .master("local")
              .getOrCreate()

            val sc = sparkSession.sparkContext

            import sparkSession.sqlContext.implicits._

            val rawDf = Seq(
              (1, "Moe", "Slap", 2.0, 18),
              (2, "Larry", "Spank", 3.0, 15),
              (3, "Curly", "Twist", 5.0, 15),
              (4, "Laurel", "Whimper", 3.0, 9),
              (5, "Hardy", "Laugh", 6.0, 18),
              (6, "Charley", "Ignore", 5.0, 5)
            ).toDF("id", "name", "requisite", "money", "age")

            rawDf.show(false)
            rawDf.printSchema

            val rawSchema = rawDf.schema

            // UDF that reduces the collected rows of a group to the row with the most money.
            val fUdf = udf(reduceByMoney, rawSchema)

            // UDF that extracts the "name" field from the reduced row.
            val nameUdf = udf(extractName, StringType)

            val aggDf = rawDf
              .groupBy("age")
              .agg(
                count(struct("*")).as("count"),
                max(col("money")),
                collect_list(struct("*")).as("horizontal")
              )
              .withColumn("short", fUdf($"horizontal"))
              .withColumn("name", nameUdf($"short"))
              .drop("horizontal")

            aggDf.printSchema

            aggDf.show(false)
          }

          // Reduce a collected list of rows to the single row with the highest "money".
          def reduceByMoney = (x: Any) => {

            val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

            d.reduce((r1, r2) => {
              val money1 = r1.getAs[Double]("money")
              val money2 = r2.getAs[Double]("money")

              if (money1 >= money2) r1 else r2
            })
          }

          // Pull the "name" field out of the reduced row.
          def extractName = (x: Any) => {

            val d = x.asInstanceOf[GenericRowWithSchema]

            d.getAs[String]("name")
          }
        }
        

Here is the output:

        +---+-----+----------+----------------------------+-------+
        |age|count|max(money)|short                       |name   |
        +---+-----+----------+----------------------------+-------+
        |5  |1    |5.0       |[6, Charley, Ignore, 5.0, 5]|Charley|
        |15 |2    |5.0       |[3, Curly, Twist, 5.0, 15]  |Curly  |
        |9  |1    |3.0       |[4, Laurel, Whimper, 3.0, 9]|Laurel |
        |18 |2    |6.0       |[5, Hardy, Laugh, 6.0, 18]  |Hardy  |
        +---+-----+----------+----------------------------+-------+
        

【Comments】:

【Solution 9】:

You can do it like this:

Sample data:

          name    age id
          abc     24  1001
          cde     24  1002
          efg     22  1003
          ghi     21  1004
          ijk     20  1005
          klm     19  1006
          mno     18  1007
          pqr     18  1008
          rst     26  1009
          tuv     27  1010
          pqr     18  1012
          rst     28  1013
          tuv     29  1011
          
          df.select("name","age","id").groupBy("name","age").count().show();
          

Output:

              +----+---+-----+
              |name|age|count|
              +----+---+-----+
              | efg| 22|    1|
              | tuv| 29|    1|
              | rst| 28|    1|
              | klm| 19|    1|
              | pqr| 18|    2|
              | cde| 24|    1|
              | tuv| 27|    1|
              | ijk| 20|    1|
              | abc| 24|    1|
              | mno| 18|    1|
              | ghi| 21|    1|
              | rst| 26|    1|
              +----+---+-----+
          

【Comments】:

  • This answer is not correct, because the original poster only needs to group by one field.
  • Actually I think this answer has some merit; at least for my use case, grouping on the two attributes was acceptable. It follows the intent of the "original" SQL.