【问题标题】:Transform spark string column to vectorUDT将 spark 字符串列转换为 vectorUDT
【发布时间】:2020-02-28 10:58:42
【问题描述】:

我正在使用 pyspark。
首先,我正在阅读带有字符串列的 csv。里面的数据是这样的:

(174, [7, 10, 56, 89, 156], [1.0, 1.0, 1.0, 1.0, 1.0])

我需要将其转换为 VectorUDT,以将该列作为机器学习算法的输入。

我已经尝试投专栏了:

data = data.withColumn(column, data[column].cast(VectorUDT())

但它不起作用......

你有什么解决办法吗?

【问题讨论】:

    标签: python dataframe apache-spark pyspark


    【解决方案1】:

    您可以尝试的一种方法是使用 VectorUDT 的内部 _sqlType(见下文)构造 JSON 字符串,然后使用 from_json 函数:

    struct<type:tinyint,size:int,indices:array<int>,values:array<double>> 
    

    首先,删除前导括号和尾括号,并使用正则表达式,\s*(?=\[) 拆分字符串。这会将 StringType 列转换为名为 s 的 ArrayType 列,其中 3 个项目对应于 s[0]=sizes[1]=indices s[2]=values,然后使用concat函数创建JSON字符串。

    注意:在本例中,我们用 array 包裹 VectorUDT,因为 from_json 函数只采用其中一种复杂数据类型:array , mapstruct。您也可以尝试使用 mapstruct 作为包装器。

    from pyspark.ml.linalg import VectorUDT
    from pyspark.sql.types import ArrayType
    from pyspark.sql.functions import expr, from_json
    
    df = spark.createDataFrame([('(174, [7, 10, 56, 89, 156], [1.0, 1.0, 1.0, 1.0, 1.0])',)],['column'])
    # DataFrame[column: string]
    
    df_new = df.withColumn("s", expr("split(substr(column,2,length(column)-2), ',\\\\s*(?=\\\\[)')")) \
      .selectExpr("""
          concat(
            /* type = 0 for SparseVector and type = 1 for DenseVector */
            '[{"type":0,"size":',
            s[0],
            ',"indices":',
            s[1],
            ',"values":',
            s[2],
            '}]'
          ) as vec_json
       """) \
      .withColumn('features', from_json('vec_json', ArrayType(VectorUDT()))[0])
    

    注意我们这里构造的 JSON 字符串是一个带有单个 VectorUDT 的数组,然后我们可以使用 from_jsongetItem(0) 来检索向量。

    结果:

    df_new.printSchema()
    root
     |-- vec_json: string (nullable = true)
     |-- features: vector (nullable = true)
    
    df_new.show(truncate=False, vertical=True)
    -RECORD 0---------------------------------------------------------------------------------------------
     vec_json | [{"type":0,"size":174,"indices":[7, 10, 56, 89, 156],"values":[1.0, 1.0, 1.0, 1.0, 1.0]}]
     features | (174,[7,10,56,89,156],[1.0,1.0,1.0,1.0,1.0])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-11-25
      • 2016-12-06
      • 2022-01-17
      • 2021-03-26
      • 2017-06-03
      • 2017-11-03
      • 1970-01-01
      • 2018-07-29
      相关资源
      最近更新 更多