【问题标题】:How to write Spark UDF which takes Array[StructType], StructType as input and return Array[StructType]如何编写以 Array[StructType]、StructType 作为输入并返回 Array[StructType] 的 Spark UDF
【发布时间】:2020-06-28 06:18:33
【问题描述】:

我有一个具有以下架构的 DataFrame:

root
 |-- user_id: string (nullable = true)
 |-- user_loans_arr: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- loan_date: string (nullable = true)
 |    |    |-- loan_amount: string (nullable = true)
 |-- new_loan: struct (nullable = true)
 |    |-- loan_date : string (nullable = true)
 |    |-- loan_amount : string (nullable = true)

我想使用 UDF,它以 user_loans_arrnew_loan 作为输入,并将 new_loan 结构添加到现有的 user_loans_arr。然后,从 user_loans_arr 中删除所有 Loan_date 超过 12 个月的元素。

提前致谢。

【问题讨论】:

    标签: scala apache-spark user-defined-functions


    【解决方案1】:

    如果 spark >= 2.4 则不需要 UDF,请查看下面的示例-

    加载输入数据

     val df = spark.sql(
          """
            |select user_id, user_loans_arr, new_loan
            |from values
            | ('u1', array(named_struct('loan_date', '2019-01-01', 'loan_amount', 100)), named_struct('loan_date',
            | '2020-01-01', 'loan_amount', 100)),
            | ('u2', array(named_struct('loan_date', '2020-01-01', 'loan_amount', 200)), named_struct('loan_date',
            | '2020-01-01', 'loan_amount', 100))
            | T(user_id, user_loans_arr, new_loan)
          """.stripMargin)
        df.show(false)
        df.printSchema()
    
        /**
          * +-------+-------------------+-----------------+
          * |user_id|user_loans_arr     |new_loan         |
          * +-------+-------------------+-----------------+
          * |u1     |[[2019-01-01, 100]]|[2020-01-01, 100]|
          * |u2     |[[2020-01-01, 200]]|[2020-01-01, 100]|
          * +-------+-------------------+-----------------+
          *
          * root
          * |-- user_id: string (nullable = false)
          * |-- user_loans_arr: array (nullable = false)
          * |    |-- element: struct (containsNull = false)
          * |    |    |-- loan_date: string (nullable = false)
          * |    |    |-- loan_amount: integer (nullable = false)
          * |-- new_loan: struct (nullable = false)
          * |    |-- loan_date: string (nullable = false)
          * |    |-- loan_amount: integer (nullable = false)
          */
    

    按以下要求处理

    user_loans_arr 和 new_loan 作为输入,并将 new_loan 结构添加到现有的 user_loans_arr。然后,从 user_loans_arr 中删除所有 Loan_date 超过 12 个月的元素。

    spark >= 2.4

        df.withColumn("user_loans_arr",
          expr(
            """
              |FILTER(array_union(user_loans_arr, array(new_loan)),
              | x -> months_between(current_date(), to_date(x.loan_date)) < 12)
            """.stripMargin))
          .show(false)
    
        /**
          * +-------+--------------------------------------+-----------------+
          * |user_id|user_loans_arr                        |new_loan         |
          * +-------+--------------------------------------+-----------------+
          * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
          * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
          * +-------+--------------------------------------+-----------------+
          */
    

    spark &lt; 2.4

     // spark < 2.4
        val outputSchema = df.schema("user_loans_arr").dataType
    
        import java.time._
        val add_and_filter = udf((userLoansArr: mutable.WrappedArray[Row], loan: Row) => {
          (userLoansArr :+ loan).filter(row => {
            val loanDate = LocalDate.parse(row.getAs[String]("loan_date"))
            val period = Period.between(loanDate, LocalDate.now())
            period.getYears * 12 + period.getMonths < 12
          })
        }, outputSchema)
    
        df.withColumn("user_loans_arr", add_and_filter($"user_loans_arr", $"new_loan"))
          .show(false)
    
        /**
          * +-------+--------------------------------------+-----------------+
          * |user_id|user_loans_arr                        |new_loan         |
          * +-------+--------------------------------------+-----------------+
          * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
          * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
          * +-------+--------------------------------------+-----------------+
          */
    

    【讨论】:

    • 嗨 Someshwar Kale,感谢您的回答。它真的帮了我很多。你能帮我解决以下情况吗?对于任何用户,如果 user_loans_arr 为空并且该用户获得了 new_loan,我需要创建一个新的 user_loans_arr 数组并将 new_loan 添加到它。截至目前,我将该用户的 user_loans_arr 值设为 null。
    • 你能问那个问题吗?再次使用样本数据?顺便说一句,它应该在类似的线路上尝试一下
    • @Someshwar kale - 我不太了解 scala。但是在您的示例中,您似乎创建了一个包含结构数组的列。这是正确的吗? Pyspark 不允许我这样做。这是某种限制吗?
    • @Raghu,你能在这里引用声明吗?如果您正在谈论创建输入的第一条语句,那么它也应该在 pyspark 中工作,因为它是 spark sql
    • @SomeshwarKale - 是的,创建数据部分。我在 pyspark 中试过这个。 tst = sqlContext.createDataFrame([(1,2,3,4),(3,4,5,4),(5,6,7,5),(7,8,9,5)],schema=['col1','col2','col3','col4']) tst_s = tst.withColumn("test",F.struct('col1','col2')).withColumn("test2",F.struct('col2','col3')) tst_ar=tst_s.withColumn("arr",F.array('test','test2')) 然后我得到这个错误: u"cannot resolve 'array(test, test2)' 由于数据类型不匹配:函数数组的输入应该都是相同的类型,但它是 [struct, struct]
    【解决方案2】:

    您需要将数组和结构列作为数组或结构传递给 udf。我更喜欢将它作为结构传递。 在那里你可以操作元素并返回一个数组类型。

    import pyspark.sql.functions as F
    from pyspark.sql.functions import udf
    from pyspark.sql.types import *
    import numpy as np
    #Test data
    tst = sqlContext.createDataFrame([(1,2,3,4),(3,4,5,1),(5,6,7,8),(7,8,9,2)],schema=['col1','col2','col3','col4'])
    tst_1=(tst.withColumn("arr",F.array('col1','col2'))).withColumn("str",F.struct('col3','col4'))
    # udf to return array
    @udf(ArrayType(StringType()))
    def fn(row):
        if(row.arr[1]>row.str.col4):
            res=[]
        else:
            res.append(row.str[i])        
            res = row.arr+row.str.asDict().values()        
        return(res)
    # calling udf with a struct of array and struct column
    tst_fin = tst_1.withColumn("res",fn(F.struct('arr','str')))
    

    结果是

    tst_fin.show()
    +----+----+----+----+------+------+------------+
    |col1|col2|col3|col4|   arr|   str|         res|
    +----+----+----+----+------+------+------------+
    |   1|   2|   3|   4|[1, 2]|[3, 4]|[1, 2, 4, 3]|
    |   3|   4|   5|   1|[3, 4]|[5, 1]|          []|
    |   5|   6|   7|   8|[5, 6]|[7, 8]|[5, 6, 8, 7]|
    |   7|   8|   9|   2|[7, 8]|[9, 2]|          []|
    +----+----+----+----+------+------+----------
    

    此示例将所有内容都视为 int。由于您将字符串作为 date ,因此在您的 udf 中,您必须使用 python 的 datetime 函数进行比较。

    【讨论】:

    • 您好 Raghu,谢谢您的回答。
    猜你喜欢
    • 2017-08-13
    • 1970-01-01
    • 2017-11-24
    • 2018-10-21
    • 2017-12-13
    • 1970-01-01
    • 1970-01-01
    • 2021-03-08
    • 1970-01-01
    相关资源
    最近更新 更多