【Question title】: Spark - Transforming Complex Data Types
【Posted】: 2019-10-04 08:28:15
【Question description】:

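Goal

What I want to achieve is to:

  • read a CSV file (works)
  • encode it as a Dataset<Person>, where the Person object has a nested Address[] (throws an exception)

The person CSV file

A file named person.csv contains the following data describing some people: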

name,age,address
"name1",10,"streetA~cityA||streetB~cityB"
"name2",20,"streetA~cityA||streetB~cityB"

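The first line is the header. The address column is a nested structure: `||` separates individual addresses, and `~` separates the street from the city.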
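To make the column format concrete, here is a minimal plain-Java sketch (no Spark involved) of unpacking one address value into street/city pairs. The class and method names are hypothetical and do not appear in the original question.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

// Hypothetical helper, not part of the original question.
class AddressColumnSplitter {

    // Splits "streetA~cityA||streetB~cityB" into [street, city] pairs.
    static List<String[]> split(String column) {
        List<String[]> pairs = new ArrayList<>();
        // "|" is the regex alternation operator, so the delimiter must be quoted.
        // The limit -1 keeps trailing empty entries instead of dropping them.
        for (String entry : column.split(Pattern.quote("||"), -1)) {
            pairs.add(entry.split("~", -1));
        }
        return pairs;
    }

    public static void main(String[] args) {
        for (String[] pair : split("streetA~cityA||streetB~cityB")) {
            System.out.println(pair[0] + " / " + pair[1]);
        }
    }
}
```

Quoting the delimiter matters here: an unquoted `"||"` is a regex of empty alternatives, so it would not split on the literal two-character separator.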

Data classes

The data classes are:

@Data
public class Address implements Serializable {
    public String street;
    public String city;
}

@Data
public class Person implements Serializable {
    public String name;
    public Integer age;
    public Address[] address;
}

Reading untyped data

I first tried reading the CSV into a Dataset<Row>, which works as expected:

    Dataset<Row> ds = spark.read() //
                           .format("csv") //
                           .option("header", "true") // first line has headers
                           .load("src/test/resources/outer/person.csv");

    LOG.info("=============== Print schema =============");
    ds.printSchema();

root
|-- name: string (nullable = true)
|-- age: string (nullable = true)
|-- address: string (nullable = true)

    LOG.info("================ Print data ==============");
    ds.show();

+-----+---+--------------------+
| name|age|             address|
+-----+---+--------------------+
|name1| 10|streetA~cityA||st...|
|name2| 20|streetA~cityA||st...|
+-----+---+--------------------+

    LOG.info("================ Print name ==============");
    ds.select("name").show();

+-----+
| name|
+-----+
|name1|
|name2|
+-----+

    assertThat(ds.isEmpty(), is(false)); //OK
    assertThat(ds.count(), is(2L)); //OK
    final List<String> names = ds.select("name").as(Encoders.STRING()).collectAsList();
    assertThat(names, hasItems("name1", "name2")); //OK

Encoding via a UserDefinedFunction

My UDF takes a String and returns an Address[]:

private static void registerAsAddress(SparkSession spark) {
    spark.udf().register("asAddress", new UDF1<String, Address[]>() {

                             @Override
                             public Address[] call(String rowValue) {
                                 return Arrays.stream(rowValue.split(Pattern.quote("||"), -1)) //
                                              .map(object -> object.split("~")) //
                                              .map(Address::fromArgs) //
                                              .map(a -> a.orElse(null)) //
                                              .toArray(Address[]::new);
                             }
                         },  //
                         DataTypes.createArrayType(DataTypes.createStructType(
                                 new StructField[]{new StructField("street", DataTypes.StringType, true, Metadata.empty()), //
                                                   new StructField("city", DataTypes.StringType, true, Metadata.empty()) //
                                 })));
}
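The UDF above calls Address::fromArgs, which the question does not show. Judging from the `.map(a -> a.orElse(null))` step, it presumably takes the String[] produced by the split and returns an Optional. A minimal sketch under that assumption follows; the class is named AddressSketch to mark it as hypothetical, and the "exactly two parts" rule is a guess.

```java
import java.io.Serializable;
import java.util.Optional;

// Stand-in for the question's Address class (hypothetical name).
// Lombok's @Data is replaced by a plain constructor.
class AddressSketch implements Serializable {
    public String street;
    public String city;

    AddressSketch(String street, String city) {
        this.street = street;
        this.city = city;
    }

    // Assumed contract: exactly two parts (street, city) yield an address,
    // anything else yields Optional.empty().
    static Optional<AddressSketch> fromArgs(String[] parts) {
        if (parts == null || parts.length != 2) {
            return Optional.empty();
        }
        return Optional.of(new AddressSketch(parts[0], parts[1]));
    }
}
```

With a helper like this, a malformed entry becomes a null element in the returned array, because the UDF unwraps each Optional with orElse(null).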

Caller:

    @Test
    void asAddressTest() throws URISyntaxException {
        registerAsAddress(spark);

        // given, when
        Dataset<Row> ds = spark.read() //
                               .format("csv") //
                               .option("header", "true") // first line has headers
                               .load("src/test/resources/outer/person.csv");

        ds.show();
        // create a typed dataset
        Encoder<Person> personEncoder = Encoders.bean(Person.class);
        Dataset<Person> typed = ds.withColumn("address2", //
                                                callUDF("asAddress", ds.col("address")))
                .drop("address").withColumnRenamed("address2", "address")
                .as(personEncoder);
        LOG.info("Typed Address");
        typed.show();
        typed.printSchema();
    }

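This results in the following exception:

Caused by: java.lang.IllegalArgumentException: The value (Address(street=streetA, city=cityA)) of the type (ch.project.data.Address) cannot be converted to struct

Why can't it be converted from Address to a struct?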

【Comments】:

  • Try replacing Address[] address; with List<Address> address;
  • Unfortunately, the same exception. Any other ideas? It looks to me as if Spark cannot infer the schema. What do you think?

Tags: java apache-spark apache-spark-sql user-defined-functions


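【Solution 1】:

After trying many different approaches and spending hours researching online, I came to the following conclusion:

A UserDefinedFunction works, but it comes from the old world. When all we need is to convert objects from one type to another, it can be replaced by a simple map() function. The simplest way is as follows: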

    SparkSession spark = SparkSession.builder().appName("CSV to Dataset").master("local").getOrCreate();
    Encoder<FileFormat> fileFormatEncoder = Encoders.bean(FileFormat.class);
    Dataset<FileFormat> rawFile = spark.read() //
                                       .format("csv") //
                                       .option("inferSchema", "true") //
                                       .option("header", "true") // first line has headers
                                       .load("src/test/resources/encoding-tests/persons.csv") //
                                       .as(fileFormatEncoder);

    LOG.info("=============== Print schema =============");
    rawFile.printSchema();
    LOG.info("================ Print data ==============");
    rawFile.show();
    LOG.info("================ Print name ==============");
    rawFile.select("name").show();

    // when
    final SerializableFunction<String, List<Address>> asAddress = (String text) -> Arrays
            .stream(text.split(Pattern.quote("||"), -1)) //
            .map(object -> object.split("~")) //
            .map(Address::fromArgs) //
            .map(a -> a.orElse(null)).collect(Collectors.toList());

    final MapFunction<FileFormat, Person> personMapper = (MapFunction<FileFormat, Person>) row -> new Person(row.name,
                                                                                                             row.age,
                                                                                                             asAddress
                                                                                                                     .apply(row.address));
    final Encoder<Person> personEncoder = Encoders.bean(Person.class);
    Dataset<Person> persons = rawFile.map(personMapper, personEncoder);
    persons.show();

    // then
    assertThat(persons.isEmpty(), is(false));
    assertThat(persons.count(), is(2L));
    final List<String> names = persons.select("name").as(Encoders.STRING()).collectAsList();
    assertThat(names, hasItems("name1", "name2"));
    final List<Integer> ages = persons.select("age").as(Encoders.INT()).collectAsList();
    assertThat(ages, hasItems(10, 20));
    final Encoder<Address> addressEncoder = Encoders.bean(Address.class);
    final MapFunction<Person, Address> firstAddressMapper = (MapFunction<Person, Address>) person -> person.addresses.get(0);
    final List<Address> addresses = persons.map(firstAddressMapper, addressEncoder).collectAsList();
    assertThat(addresses, hasItems(new Address("streetA", "cityA"), new Address("streetC", "cityC")));
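The answer does not show the FileFormat class or the Person constructor it relies on. The plain-Java sketch below illustrates the conversion that personMapper performs, assuming FileFormat is a flat bean mirroring the CSV columns and Person holds a List of addresses in a field named addresses (as the `person.addresses.get(0)` call suggests). All class names here are hypothetical stand-ins, and unlike the answer's code, this sketch skips malformed entries rather than mapping them to null.

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

// Plays the role of FileFormat: one flat object per CSV row (assumed shape).
class FlatPerson implements Serializable {
    public String name;
    public Integer age;
    public String address; // still the raw "street~city||street~city" text
}

class NestedAddress implements Serializable {
    public String street;
    public String city;

    NestedAddress(String street, String city) {
        this.street = street;
        this.city = city;
    }
}

// Plays the role of Person in the answer (assumed constructor and field names).
class NestedPerson implements Serializable {
    public String name;
    public Integer age;
    public List<NestedAddress> addresses;

    NestedPerson(String name, Integer age, List<NestedAddress> addresses) {
        this.name = name;
        this.age = age;
        this.addresses = addresses;
    }
}

class PersonMapping {
    // Converts one flat row into a person with nested addresses.
    static NestedPerson toPerson(FlatPerson row) {
        List<NestedAddress> addresses = new ArrayList<>();
        for (String entry : row.address.split(Pattern.quote("||"), -1)) {
            String[] parts = entry.split("~", -1);
            if (parts.length == 2) { // skip entries that are not street~city
                addresses.add(new NestedAddress(parts[0], parts[1]));
            }
        }
        return new NestedPerson(row.name, row.age, addresses);
    }
}
```

In the answer's Spark code, logic of this kind sits inside the MapFunction passed to rawFile.map(personMapper, personEncoder), so the conversion happens on typed objects instead of going through a UDF return schema.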

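    【Solution 2】:

    Use Row instead of a Java class in your UDF. The UDF is registered with a struct return type, so Spark expects Row values rather than Address objects: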

    public static UDF1<String, Row> personParseUdf = new UDF1<String, Row>() {
        @Override
        public Row call(String s) throws Exception {
            PersonEntity personEntity = PersonEntity.parse(s);
            List<Row> rowList = new ArrayList<>();
            for (AddressEntity addressEntity : personEntity.getAddress()) {
                //  use row instead of java class
                rowList.add(RowFactory.create(addressEntity.getStreet(), addressEntity.getCity()));
            }
            return RowFactory.create(personEntity.getName(), personEntity.getAge(), rowList);
        }
    };
    

    See details: https://cloud.tencent.com/developer/article/1674399
