【问题标题】:Stratified sampling with Spark and Java使用 Spark 和 Java 进行分层采样
【发布时间】:2017-05-31 01:54:27
【问题描述】:

我想确保我正在对我的数据的分层样本进行训练。

here 所述,Spark 2.1 及更早版本似乎通过 JavaPairRDD.sampleByKey(...)JavaPairRDD.sampleByKeyExact(...) 支持此功能。

但是:我的数据存储在Dataset<Row>,而不是JavaPairRDD。第一列是标签,其他都是特征(从 libsvm 格式的文件导入)。

获取我的数据集实例的分层样本并最终再次获得Dataset<Row> 的最简单方法是什么?

在某种程度上,这个问题与Dealing with unbalanced datasets in Spark MLlib有关。

这个possible duplicate 根本没有提到Dataset<Row>,Java 中也没有。它没有回答我的问题。

【问题讨论】:

    标签: java apache-spark machine-learning apache-spark-mllib


    【解决方案1】:

    好的,因为the question here 的答案实际上并不是针对Java,我已经用Java 重写了它。

    推理还是一样的想法。我们仍在使用sampleByKeyExact。目前没有开箱即用的奇迹功能(spark 2.1.0

    所以你去:

    package org.awesomespark.examples;
    
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.PairFunction;
    import org.apache.spark.sql.*;
    import scala.Tuple2;
    
    import java.util.Map;
    
    public class StratifiedDatasets {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .appName("Stratified Datasets")
                    .getOrCreate();
    
            Dataset<Row> data = spark.read().format("libsvm").load("sample_libsvm_data.txt");
    
            JavaPairRDD<Double, Row> rdd = data.toJavaRDD().keyBy(x -> x.getDouble(0));
            Map<Double, Double> fractions = rdd.map(Tuple2::_1)
                    .distinct()
                    .mapToPair((PairFunction<Double, Double, Double>) (Double x) -> new Tuple2(x, 0.8))
                    .collectAsMap();
    
            JavaRDD<Row> sampledRDD = rdd.sampleByKeyExact(false, fractions, 2L).values();
            Dataset<Row> sampledData = spark.createDataFrame(sampledRDD, data.schema());
    
            sampledData.show();
            sampledData.printSchema();
        }
    }
    

    现在让我们打包并提交:

    $ sbt package
    [...]
    // [success] Total time: 2 s, completed Jan 16, 2017 1:45:51 PM
    
    $ spark-submit --class org.awesomespark.examples.StratifiedDatasets target/scala-2.10/java-stratified-dataset_2.10-1.0.jar 
    [...]
    
    // +-----+--------------------+
    // |label|            features|
    // +-----+--------------------+
    // |  0.0|(692,[127,128,129...|
    // |  1.0|(692,[158,159,160...|
    // |  1.0|(692,[124,125,126...|
    // |  1.0|(692,[152,153,154...|
    // |  1.0|(692,[151,152,153...|
    // |  0.0|(692,[129,130,131...|
    // |  1.0|(692,[99,100,101,...|
    // |  0.0|(692,[154,155,156...|
    // |  0.0|(692,[127,128,129...|
    // |  1.0|(692,[154,155,156...|
    // |  0.0|(692,[151,152,153...|
    // |  1.0|(692,[129,130,131...|
    // |  0.0|(692,[154,155,156...|
    // |  1.0|(692,[150,151,152...|
    // |  0.0|(692,[124,125,126...|
    // |  0.0|(692,[152,153,154...|
    // |  1.0|(692,[97,98,99,12...|
    // |  1.0|(692,[124,125,126...|
    // |  1.0|(692,[156,157,158...|
    // |  1.0|(692,[127,128,129...|
    // +-----+--------------------+
    // only showing top 20 rows
    
    // root
    //  |-- label: double (nullable = true)
    //  |-- features: vector (nullable = true)
    

    python用户也可以查看我的回答Stratified sampling with pyspark

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-05-18
      • 2015-11-21
      • 1970-01-01
      • 2018-11-22
      • 1970-01-01
      • 2022-09-30
      • 2014-09-08
      • 1970-01-01
      相关资源
      最近更新 更多