【Question Title】: Spark Scala: Count Consecutive Months
【Posted】: 2018-03-29 17:53:19
【Question】:

I have the following example DataFrame:

Provider  Patient  Date
Smith     John     2016-01-23
Smith     John     2016-02-20
Smith     John     2016-03-21
Smith     John     2016-06-25
Smith     Jill     2016-02-01
Smith     Jill     2016-03-10
James     Jill     2017-04-10
James     Jill     2017-05-11

I would like to programmatically add a column indicating the number of consecutive months a patient has seen the doctor. The new DataFrame would look like this:

Provider  Patient  Date         consecutive_id
Smith     John     2016-01-23   3
Smith     John     2016-02-20   3
Smith     John     2016-03-21   3
Smith     John     2016-06-25   1
Smith     Jill     2016-02-01   2
Smith     Jill     2016-03-10   2
James     Jill     2017-04-10   2
James     Jill     2017-05-11   2

I assume there is a way to achieve this with a Window function, but I haven't figured it out yet, and I look forward to any insight the community can offer. Thanks.

【Question Comments】:

    Tags: scala apache-spark spark-dataframe


    【Solution 1】:

    There are at least 3 ways to get the result:

    1. Implement the logic in SQL
    2. Use the Spark API for window functions - .over(windowSpec)
    3. Use .rdd.mapPartitions directly

    Introducing Window Functions in Spark SQL

    For all of these you can call .toDebugString to see the operations under the hood.

    The SQL solution is below:

    val my_df = List(
      ("Smith", "John", "2016-01-23"),
      ("Smith", "John", "2016-02-20"),
      ("Smith", "John", "2016-03-21"),
      ("Smith", "John", "2016-06-25"),
      ("Smith", "Jill", "2016-02-01"),
      ("Smith", "Jill", "2016-03-10"),
      ("James", "Jill", "2017-04-10"),
      ("James", "Jill", "2017-05-11")
      ).toDF(Seq("Provider", "Patient", "Date"): _*)
    
    my_df.createOrReplaceTempView("tbl")
    
    val q = """
    select t2.*, count(*) over (partition by provider, patient, grp) consecutive_id
      from (select t1.*, sum(x) over (partition by provider, patient order by yyyymm) grp
              from (select t0.*,
                           case
                              when cast(yyyymm as int) - 
                                   cast(lag(yyyymm) over (partition by provider, patient order by yyyymm) as int) = 1
                              then 0
                              else 1
                           end x
                      from (select tbl.*, substr(translate(date, '-', ''), 1, 6) yyyymm from tbl) t0) t1) t2
    """
    
    sql(q).show
    sql(q).rdd.toDebugString
    

    Output

    scala> sql(q).show
    +--------+-------+----------+------+---+---+--------------+
    |Provider|Patient|      Date|yyyymm|  x|grp|consecutive_id|
    +--------+-------+----------+------+---+---+--------------+
    |   Smith|   Jill|2016-02-01|201602|  1|  1|             2|
    |   Smith|   Jill|2016-03-10|201603|  0|  1|             2|
    |   James|   Jill|2017-04-10|201704|  1|  1|             2|
    |   James|   Jill|2017-05-11|201705|  0|  1|             2|
    |   Smith|   John|2016-01-23|201601|  1|  1|             3|
    |   Smith|   John|2016-02-20|201602|  0|  1|             3|
    |   Smith|   John|2016-03-21|201603|  0|  1|             3|
    |   Smith|   John|2016-06-25|201606|  1|  2|             1|
    +--------+-------+----------+------+---+---+--------------+
    
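    The CASE/LAG/SUM/COUNT pipeline above is the classic "gaps and islands" pattern. As a sketch, the same three steps can be illustrated locally in plain Scala (no Spark) on John's visits. One caveat worth noting: comparing raw yyyymm integers breaks across year boundaries (201701 - 201612 = 89, not 1), so this sketch uses a month serial year*12 + month instead; the sample data never spans a year boundary within a streak, which is why the SQL output above is still correct.

    ```scala
    // Plain-Scala sketch of the gaps-and-islands logic in the SQL above,
    // applied to John's visits. year*12 + month is used as a month serial
    // so consecutive months across a year boundary also differ by 1.
    def monthSerial(date: String): Int =
      date.substring(0, 4).toInt * 12 + date.substring(5, 7).toInt

    val visits  = List("2016-01-23", "2016-02-20", "2016-03-21", "2016-06-25")
    val serials = visits.map(monthSerial)

    // x: 1 starts a new streak, 0 continues one (the CASE/LAG step)
    val x = 1 :: serials.zip(serials.tail).map {
      case (prev, cur) => if (cur == prev + 1) 0 else 1
    }

    // grp: running sum of x (the SUM ... OVER step)
    val grp = x.scanLeft(0)(_ + _).tail

    // consecutive_id: size of each group (the COUNT ... OVER step)
    val sizes  = grp.groupBy(identity).map { case (g, v) => g -> v.size }
    val result = visits.zip(grp.map(sizes))
    // result: List((2016-01-23,3), (2016-02-20,3), (2016-03-21,3), (2016-06-25,1))
    ```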

    Update

    A hybrid of .mapPartitions + .over(windowSpec):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
    
    val schema = new StructType().add(
                 StructField("provider", StringType, true)).add(
                 StructField("patient", StringType, true)).add(
                 StructField("date", StringType, true)).add(
                 StructField("x", IntegerType, true)).add(
                 StructField("grp", IntegerType, true))
    
    def f(iter: Iterator[Row]) : Iterator[Row] = {
      iter.scanLeft(Row("_", "_", "000000", 0, 0))
      {
        case (x1, x2) =>
    
        val x = 
        if (x2.getString(2).replaceAll("-", "").substring(0, 6).toInt ==
            x1.getString(2).replaceAll("-", "").substring(0, 6).toInt + 1) 
        (0) else (1);
    
        val grp = x1.getInt(4) + x;
    
        Row(x2.getString(0), x2.getString(1), x2.getString(2), x, grp);
      }.drop(1)
    }
    
    val df_mod = spark.createDataFrame(my_df.repartition($"provider", $"patient")
                                            .sortWithinPartitions($"date")
                                            .rdd.mapPartitions(f, true), schema)
    
    import org.apache.spark.sql.expressions.Window
    val windowSpec = Window.partitionBy($"provider", $"patient", $"grp")
    df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
         ).orderBy($"provider", $"patient", $"date").show
    

    Output

    scala> df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
         |      ).orderBy($"provider", $"patient", $"date").show
    +--------+-------+----------+---+---+--------------+
    |provider|patient|      date|  x|grp|consecutive_id|
    +--------+-------+----------+---+---+--------------+
    |   James|   Jill|2017-04-10|  1|  1|             2|
    |   James|   Jill|2017-05-11|  0|  1|             2|
    |   Smith|   Jill|2016-02-01|  1|  1|             2|
    |   Smith|   Jill|2016-03-10|  0|  1|             2|
    |   Smith|   John|2016-01-23|  1|  1|             3|
    |   Smith|   John|2016-02-20|  0|  1|             3|
    |   Smith|   John|2016-03-21|  0|  1|             3|
    |   Smith|   John|2016-06-25|  1|  2|             1|
    +--------+-------+----------+---+---+--------------+
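    The scanLeft-with-seed pattern inside f is the subtle part: the dummy Row("_", "_", "000000", 0, 0) exists only to prime the pairwise comparison for the first real row, and .drop(1) then discards it. A minimal non-Spark illustration of the same pattern (with the same raw-yyyymm year-boundary caveat as the SQL):

    ```scala
    // Minimal illustration of the scanLeft + drop(1) pattern used in f:
    // a seed value primes the comparison, then is discarded.
    // The seed month 0 guarantees the first real row starts a new streak.
    val months = List(201601, 201602, 201603, 201606)

    val flags = months.scanLeft((0, 0)) {        // (prevMonth, x)
      case ((prev, _), cur) => (cur, if (cur == prev + 1) 0 else 1)
    }.drop(1).map(_._2)
    // flags: List(1, 0, 0, 1)
    ```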
    

    【Comments】:

    • This works for the sample data I provided, which is why I was happy to accept it. However, I'm hitting java.lang.ArrayIndexOutOfBoundsException: 2 when I try to show the final df_mod transformation.
    【Solution 2】:

    You could:

    1. Reformat the dates into integers (2016-01 = 1, 2016-02 = 2, 2017-01 = 13 ...etc)
    2. Combine all of the dates into a single array with a window and collect_list:

      val winSpec = Window.partitionBy("Provider", "Patient").orderBy("Date")
      df.withColumn("Dates", collect_list("Date").over(winSpec))

    3. Pass the array into a modified version of @marios solution as a UDF registered with spark.udf.register, to get the maximum number of consecutive months
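    Since the @marios answer referenced in step 3 isn't reproduced here, the UDF body can only be sketched. The hypothetical helper below takes the dates collected in step 2 and returns the longest run of consecutive months, using year*12 + month as the integer month index from step 1 so that year boundaries behave:

    ```scala
    // Hypothetical sketch of the step-3 UDF body: longest run of
    // consecutive months in a list of yyyy-MM-dd date strings.
    def monthIndex(date: String): Int =
      date.substring(0, 4).toInt * 12 + date.substring(5, 7).toInt

    def maxConsecutiveMonths(dates: Seq[String]): Int = {
      val months = dates.map(monthIndex).distinct.sorted
      if (months.isEmpty) 0
      else months.zip(months.tail).foldLeft((1, 1)) {   // (best, current)
        case ((best, cur), (a, b)) =>
          val next = if (b == a + 1) cur + 1 else 1
          (math.max(best, next), next)
      }._1
    }
    ```

    It could then be registered with something like spark.udf.register("max_consecutive_months", maxConsecutiveMonths _) and applied to the Dates array column from step 2.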

    【Comments】:
