您可以尝试的一种方法是使用 VectorUDT 的内部 _sqlType(见下文)构造 JSON 字符串,然后使用 from_json 函数:
struct<type:tinyint,size:int,indices:array<int>,values:array<double>>
首先,删除前导括号和尾括号,并使用正则表达式,\s*(?=\[) 拆分字符串。这会将 StringType 列转换为名为 s 的 ArrayType 列,其中 3 个项目对应于 s[0]=size、s[1]=indices 和 s[2]=values,然后使用concat函数创建JSON字符串。
注意:在本例中,我们用 array 包裹 VectorUDT,因为 from_json 函数只采用其中一种复杂数据类型:array , map 或 struct。您也可以尝试使用 map 或 struct 作为包装器。
from pyspark.ml.linalg import VectorUDT
from pyspark.sql.types import ArrayType
from pyspark.sql.functions import expr, from_json
df = spark.createDataFrame([('(174, [7, 10, 56, 89, 156], [1.0, 1.0, 1.0, 1.0, 1.0])',)],['column'])
# DataFrame[column: string]
df_new = df.withColumn("s", expr("split(substr(column,2,length(column)-2), ',\\\\s*(?=\\\\[)')")) \
.selectExpr("""
concat(
/* type = 0 for SparseVector and type = 1 for DenseVector */
'[{"type":0,"size":',
s[0],
',"indices":',
s[1],
',"values":',
s[2],
'}]'
) as vec_json
""") \
.withColumn('features', from_json('vec_json', ArrayType(VectorUDT()))[0])
注意我们这里构造的 JSON 字符串是一个带有单个 VectorUDT 的数组,然后我们可以使用 from_json 和 getItem(0) 来检索向量。
结果:
df_new.printSchema()
root
|-- vec_json: string (nullable = true)
|-- features: vector (nullable = true)
df_new.show(truncate=False, vertical=True)
-RECORD 0---------------------------------------------------------------------------------------------
vec_json | [{"type":0,"size":174,"indices":[7, 10, 56, 89, 156],"values":[1.0, 1.0, 1.0, 1.0, 1.0]}]
features | (174,[7,10,56,89,156],[1.0,1.0,1.0,1.0,1.0])