【问题标题】:Applying UDF while Joining two Spark dataframes with intervals在以间隔连接两个 Spark 数据帧时应用 UDF
【发布时间】:2019-02-01 12:59:10
【问题描述】:

我有一个包含三列的数据框:idindexvalue

+---+-----+-------------------+
| id|index|              value|
+---+-----+-------------------+
|  A| 1023|0.09938822262205915|
|  A| 1046| 0.3110047630613805|
|  A| 1069| 0.8486710971453512|
+---+-----+-------------------+

root
 |-- id: string (nullable = true)
 |-- index: integer (nullable = false)
 |-- value: double (nullable = false)

然后,我有另一个数据框,显示每个 id 的理想周期:

+---+-----------+---------+
| id|start_index|end_index|
+---+-----------+---------+
|  A|       1069|     1276|
|  B|       2066|     2291|
|  B|       1616|     1841|
|  C|       3716|     3932|
+---+-----------+---------+

root
 |-- id: string (nullable = true)
 |-- start_index: integer (nullable = false)
 |-- end_index: integer (nullable = false)

我有如下三个模板

val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)

目标是,对于dfIntervals 中的每一行,应用一个函数(假设它是相关的),其中该函数接收来自dfRawvalue 列和三个模板数组并将三列添加到dfIntervals,与每个模板相关的每一列。

假设: 1 - 模板数组的大小正好是 10。

2 - dfRawindex 列中没有重复项

3 - dfIntervals 中的 start_indexend_index 列存在于 dfRawindex 列中,并且它们之间正好有 10 行。例如,dfRaw.filter($"id" === "A").filter($"index" >= 1069 && $"index" <= 1276).count(dfIntervals 中的第一行)的结果正好是 10

以下是生成这些数据帧的代码:

import org.apache.spark.sql.functions._
val mySeed = 1000

/* Defining templates for correlation analysis*/
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)

/* Defining raw data*/
var dfRaw = Seq(
  ("A", (1023 to 1603 by 23).toArray),
  ("B", (341 to 2300 by 25).toArray),
  ("C", (2756 to 3954 by 24).toArray)
).toDF("id", "index")
dfRaw = dfRaw.select($"id", explode($"index") as "index").withColumn("value", rand(seed=mySeed))

/* Defining intervals*/
var dfIntervals = Seq(
  ("A", 1069, 1276),
  ("B", 2066, 2291),
  ("B", 1616, 1841),
  ("C", 3716, 3932)
).toDF("id", "start_index", "end_index")

结果是dfIntervals 数据框添加了三列,名称为corr_w_template1corr_w_template2corr_w_template3

PS:我在 Scala 中找不到相关函数。假设存在这样的函数(如下所示),并且我们即将使用它生成一个udf

def correlation(arr1: Array[Double], arr2: Array[Double]): Double

【问题讨论】:

  • 据我了解,您需要如下 udf:def correlation(value: Double, template: Array[Double]): Double 其中template 可以是以下值之一:template1template2template3value 来自 dfRaw 。对吗?
  • 没错。我想这些函数需要应用三次才能获得原始信号和每个模板之间的相关性。
  • dfIntervals 包含start_index & end_index 所以应该是correlation(values: Array[Double], template: Array[Double]): Double 对吧?其中values 是从dfRaw 获得的,其中index 在[start_index: end_index] 范围内
  • 是的,没错。

标签: scala apache-spark dataframe user-defined-functions


【解决方案1】:

好的。

让我们定义一个 UDF 函数。

出于测试目的,假设它总是返回 1。

 val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {

     1f
  })

val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
    values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
  })

然后让我们将您的 2 个数据框与定义的规则合并,并将 value 收集到名为 values 的 1 列中。另外,请申请我们的orderUdf

 val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index")  <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
    .groupBy(dfIntervals("id"), dfIntervals("start_index"),  dfIntervals("end_index"))
    .agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))

最后,应用我们的 udf 并展示出来。

df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
    .withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
    .withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
    .show(10)

这是完整的示例代码:

import org.apache.spark.sql.functions._
  import scala.collection.JavaConverters._

  val conf = new SparkConf().setAppName("learning").setMaster("local[2]")

  val session = SparkSession.builder().config(conf).getOrCreate()



  val mySeed = 1000

  /* Defining templates for correlation analysis*/
  val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
  val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
  val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)

  val schema1 =  DataTypes.createStructType(Array(
    DataTypes.createStructField("id",DataTypes.StringType,false),
    DataTypes.createStructField("index",DataTypes.createArrayType(DataTypes.IntegerType),false)
  ))

  val schema2 =  DataTypes.createStructType(Array(
    DataTypes.createStructField("id",DataTypes.StringType,false),
    DataTypes.createStructField("start_index",DataTypes.IntegerType,false),
    DataTypes.createStructField("end_index",DataTypes.IntegerType,false)
  ))

  /* Defining raw data*/
  var dfRaw = session.createDataFrame(Seq(
    ("A", (1023 to 1603 by 23).toArray),
    ("B", (341 to 2300 by 25).toArray),
    ("C", (2756 to 3954 by 24).toArray)
  ).map(r => Row(r._1 , r._2)).asJava, schema1)

  dfRaw = dfRaw.select(dfRaw("id"), explode(dfRaw("index")) as "index")
    .withColumn("value", rand(seed=mySeed))

  /* Defining intervals*/
  var dfIntervals =  session.createDataFrame(Seq(
    ("A", 1069, 1276),
    ("B", 2066, 2291),
    ("B", 1616, 1841),
    ("C", 3716, 3932)
  ).map(r => Row(r._1 , r._2,r._3)).asJava, schema2)

  //Define udf

  val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
     1f
  })

  val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
    values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
  })


  val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index")  <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
    .groupBy(dfIntervals("id"), dfIntervals("start_index"),  dfIntervals("end_index"))
    .agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))

  df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
    .withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
    .withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
    .show(10,false)

【讨论】:

  • 我认为应该在join条件中加入dfRaw("id") === dfIntervals("id")dfIntervals.join(dfRaw, dfIntervals("start_index") &lt;= dfRaw("index") &amp;&amp; dfRaw("index") &lt;= dfIntervals( "end_index") &amp;&amp; dfRaw("id") === dfIntervals("id")),否则会有其他id的污染。当我们这样做时,值顺序会颠倒!我们需要包含 indexWin collect_list 并使用排序器 udf 来确保列表中的值是有序的。 Scala 版本的 here
  • 是的。那是我的错误。我已经更新了解决方案
  • 太棒了。感谢更新。有没有办法确保values 列已经在df 中排序?虽然这可以在 udf 函数中完成,但这样做类似于 this Python solution
  • 我的意思是根据索引对值进行排序。目前它是根据值排序的!也许使用collect_list(struct(dfRaw("index"),dfRaw("value")) 然后使用(key, xs) =&gt; (key, xs.map(_.c2).toSeq.sortBy(_._2))) 沿线的地​​图可以做到这一点?
  • 是的。太好了
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2017-08-16
  • 2018-09-13
  • 1970-01-01
  • 2018-06-02
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多