【问题标题】:DataFrame equality in Apache SparkApache Spark 中的 DataFrame 相等性
【发布时间】:2015-09-20 17:41:48
【问题描述】:

假设 df1df2 是 Apache Spark 中的两个 DataFrames,使用两种不同的机制计算得出,例如 Spark SQL 与 Scala/Java/Python API。

是否有一种惯用的方法来确定两个数据帧是否等价(相等,同构),其中等价性取决于数据(每行的列名和列值)是否相同,除了行和列的排序?

提出这个问题的动机是,通常有很多方法可以计算一些大数据结果,每种方法都有自己的取舍。在探索这些权衡时,保持正确性很重要,因此需要在有意义的测试数据集上检查等价/相等性。

【问题讨论】:

    标签: scala apache-spark dataframe apache-spark-sql rdd


    【解决方案1】:

    Apache Spark 测试套件中有一些标准方法,但其中大多数涉及在本地收集数据,如果您想对大型 DataFrame 进行相等性测试,那么这可能不是一个合适的解决方案。

    首先检查架构,然后您可以对 df3 进行交集并验证 df1、df2 和 df3 的计数是否全部相等(但是,这仅在没有重复行的情况下才有效,如果有不同的重复行方法仍然可以返回 true)。

    另一种选择是获取两个 DataFrame 的底层 RDD,映射到 (Row, 1),执行 reduceByKey 来计算每个 Row 的数量,然后将两个生成的 RDD 组合在一起,然后进行常规聚合和如果任何迭代器不相等,则返回 false。

    【讨论】:

    • 使用测试套件是一个有趣的想法。收集数据可能是中小型数据集的一种选择。那里的标准工具是什么?
    • 在重复行的情况下,如何附加一个额外的'count'列(当然通过计算functions.agg或通过SQL)然后得到相交为df3?
    • 那么如何取两个数据集的联合,然后按所有列分组(当然使用序列)并计数,过滤计数%2。如果大于 0 则返回 false。联合比交叉快,如果列不同,将返回异常(纠正我,如果我错了)
    • 我不认为这会更快,交叉点的缓慢部分是你也可以使用 groupBy 的随机播放。
    【解决方案2】:

    我不知道惯用语,但我认为您可以获得一种比较 DataFrames 的可靠方法,如下所述。 (我使用 PySpark 进行说明,但该方法可以跨语言使用。)

    a = spark.range(5)
    b = spark.range(5)
    
    a_prime = a.groupBy(sorted(a.columns)).count()
    b_prime = b.groupBy(sorted(b.columns)).count()
    
    assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
    

    这种方法可以正确处理 DataFrame 可能具有重复行、不同顺序的行和/或不同顺序的列的情况。

    例如:

    a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
    b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
    c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])
    
    a_prime = a.groupBy(sorted(a.columns)).count()
    b_prime = b.groupBy(sorted(b.columns)).count()
    c_prime = c.groupBy(sorted(c.columns)).count()
    
    assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
    assert a_prime.subtract(c_prime).count() != 0
    

    这种方法非常昂贵,但考虑到需要执行完整的差异,大部分费用是不可避免的。这应该可以很好地扩展,因为它不需要在本地收集任何东西。如果您放宽比较应该考虑重复行的约束,那么您可以删除groupBy(),而只执行subtract(),这可能会显着加快速度。

    【讨论】:

    • 请注意,这不适用于任何不可排序的数据类型,例如地图,在这种情况下,您可能必须删除这些列并单独执行。
    • 我猜 count 是在里面和 agg() 方法,否则 a_prime、b_prime 和 c_prime 是数字而不是数据帧
    • @dhalfageme - 不,.count()GroupedData 对象上——这是.groupBy() 返回的——产生一个DataFrame。试试看:spark.range(3).groupBy('id').count().show()
    【解决方案3】:

    Scala(PySpark 见下文)

    spark-fast-tests 库有两种进行 DataFrame 比较的方法(我是该库的创建者):

    assertSmallDataFrameEquality方法收集驱动节点上的DataFrame并进行比较

    def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
      if (!actualDF.schema.equals(expectedDF.schema)) {
        throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
      }
      if (!actualDF.collect().sameElements(expectedDF.collect())) {
        throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
      }
    }
    

    assertLargeDataFrameEquality方法比较分布在多台机器上的DataFrames(代码基本是从spark-testing-base复制过来的)

    def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
      if (!actualDF.schema.equals(expectedDF.schema)) {
        throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
      }
      try {
        actualDF.rdd.cache
        expectedDF.rdd.cache
    
        val actualCount = actualDF.rdd.count
        val expectedCount = expectedDF.rdd.count
        if (actualCount != expectedCount) {
          throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
        }
    
        val expectedIndexValue = zipWithIndex(actualDF.rdd)
        val resultIndexValue = zipWithIndex(expectedDF.rdd)
    
        val unequalRDD = expectedIndexValue
          .join(resultIndexValue)
          .filter {
            case (idx, (r1, r2)) =>
              !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
          }
    
        val maxUnequalRowsToShow = 10
        assertEmpty(unequalRDD.take(maxUnequalRowsToShow))
    
      } finally {
        actualDF.rdd.unpersist()
        expectedDF.rdd.unpersist()
      }
    }
    

    assertSmallDataFrameEquality 对于小型 DataFrame 比较更快,我发现它对于我的测试套件来说已经足够了。

    PySpark

    这是一个简单的函数,如果 DataFrame 相等则返回 true:

    def are_dfs_equal(df1, df2):
        if df1.schema != df2.schema:
            return False
        if df1.collect() != df2.collect():
            return False
        return True
    

    您通常会在测试套件中执行 DataFrame 相等性比较,并且在比较失败时需要描述性错误消息(True / False 返回值在调试时没有多大帮助)。

    使用chispa 库访问assert_df_equality 方法,该方法返回测试套件工作流的描述性错误消息。

    【讨论】:

    • 看起来不错的图书馆!
    • @Powers,你知道 pySpark 而不是 Scala 的类似库吗?
    • @jgtrz - 我开始构建 PySpark 版本的 spark-fast-tests,称为 chispa:github.com/MrPowers/chispa。需要完成它!
    • 对于我们这些在这里偶然发现并实施 collect 的人,请与 !actualDF.collect().sameElements(expectedDF.collect()) 进行比较。请注意下面的帖子并警惕sameElements()stackoverflow.com/questions/29008500/…的可笑性
    • 对于 Pyspark 人员:提供的功能考虑了排序。如果您只关心内容,请将第二个条件替换为:if df1.orderBy(*df1.columns).collect() !=df2.orderBy(*df2.columns).collect():
    【解决方案4】:

    您可以结合使用一点重复数据删除和完全外连接来完成此操作。这种方法的优点是它不需要您将结果收集到驱动程序,并且可以避免运行多个作业。

    import org.apache.spark.sql._
    import org.apache.spark.sql.functions._
    
    // Generate some random data.
    def random(n: Int, s: Long) = {
      spark.range(n).select(
        (rand(s) * 10000).cast("int").as("a"),
        (rand(s + 5) * 1000).cast("int").as("b"))
    }
    val df1 = random(10000000, 34)
    val df2 = random(10000000, 17)
    
    // Move all the keys into a struct (to make handling nulls easy), deduplicate the given dataset
    // and count the rows per key.
    def dedup(df: Dataset[Row]): Dataset[Row] = {
      df.select(struct(df.columns.map(col): _*).as("key"))
        .groupBy($"key")
        .agg(count(lit(1)).as("row_count"))
    }
    
    // Deduplicate the inputs and join them using a full outer join. The result can contain
    // the following things:
    // 1. Both keys are not null (and thus equal), and the row counts are the same. The dataset
    //    is the same for the given key.
    // 2. Both keys are not null (and thus equal), and the row counts are not the same. The dataset
    //    contains the same keys.
    // 3. Only the right key is not null.
    // 4. Only the left key is not null.
    val joined = dedup(df1).as("l").join(dedup(df2).as("r"), $"l.key" === $"r.key", "full")
    
    // Summarize the differences.
    val summary = joined.select(
      count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" === $"l.row_count", 1)).as("left_right_same_rc"),
      count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" =!= $"l.row_count", 1)).as("left_right_different_rc"),
      count(when($"l.key".isNotNull && $"r.key".isNull, 1)).as("left_only"),
      count(when($"l.key".isNull && $"r.key".isNotNull, 1)).as("right_only"))
    summary.show()
    

    【讨论】:

      【解决方案5】:

      Java:

      assert resultDs.union(answerDs).distinct().count() == resultDs.intersect(answerDs).count();
      

      【讨论】:

      • 有趣的解决方案,但我相信这不能正确处理重复行。例如(在 Python 中):a = spark.createDataFrame([(1,), (1,)], schema='id int'); b = spark.createDataFrame([(1,)], schema='id int'); assert a.union(b).distinct().count() == a.intersect(b).count(); assert 在应该失败的地方成功。
      • try { return ds1.union(ds2) .groupBy(columns(ds1, ds1.columns())) .count() .filter("count % 2 > 0") .count() == 0; } catch (Exception e) { return false; } where columns 方法返回 Seq 或 Column[]
      【解决方案6】:

      尝试执行以下操作:

      df1.except(df2).isEmpty
      

      【讨论】:

      • 这在df2 大于df1 的情况下不起作用。也许如果你通过添加 && df2.except(df1).isEmpty... 使其对称
      • 即使你用每种方式比较它仍然不正确,因为 df2 中的重复行与 df1 中的一行匹配,反之亦然。
      【解决方案7】:
      try {
        return ds1.union(ds2)
                .groupBy(columns(ds1, ds1.columns()))
                .count()
                .filter("count % 2 > 0")
                .count()
            == 0;
      } catch (Exception e) {
        return false;
      }
      
      Column[] columns(Dataset<Row> ds, String... columnNames) {
      List<Column> l = new ArrayList<>();
      for (String cn : columnNames) {
        l.add(ds.col(cn));
      }
      return l.stream().toArray(Column[]::new);}
      

      columns 方法是补充,可以替换为任何返回 Seq 的方法

      逻辑:

      1. 合并两个数据集,如果列不匹配,则会抛出异常并因此返回 false。
      2. 如果列匹配,则对所有列进行 groupBy 并添加列计数。现在,所有行的计数都是 2 的倍数(即使是重复行)。
      3. 检查是否有任何行的计数不能被 2 整除,这些是多余的行。

      【讨论】:

      • 有人可以确认这个联合解决方案与上面提供的连接解决方​​案相比是否具有更好的性能? (而且它也适用于重复的行)
      • 不幸的是,这不正确,如果其中一个数据集的不同行重复了两次,您将得到误报。
      【解决方案8】:

      一种可扩展且简单的方法是区分两个DataFrames 并计算不匹配的行数:

      df1.diff(df2).where($"diff" != "N").count
      

      如果该数字不为零,则两个 DataFrames 不相等。

      diff 转换由spark-extension 提供。

      它识别I插入、C挂起、D删除和uN更改的行。 p>

      【讨论】:

      • 这是否比上面使用 collect() 的 PySpark 解决方案更具可扩展性?特别是如果您不需要差异列表?
      • 如果您的意思是df1.collect() != df2.collect() PySpark 解决方案,这根本不可扩展。两个 DataFrame 都加载到驱动程序的内存中。上面的diff 转换随集群扩展,这意味着如果您的集群可以处理 DataFrame,它就可以处理差异。所以答案是:是的。
      【解决方案9】:

      根据您是否有重复行,有 4 个选项。

      假设我们有两个DataFrames,z1 和 z1。选项 1/2 适用于 没有 重复的行。你可以在spark-shell试试这些。

      • 选项 1:直接执行除外
      import org.apache.spark.sql.DataFrame
      import org.apache.spark.sql.Column
      
      def isEqual(left: DataFrame, right: DataFrame): Boolean = {
         if(left.columns.length != right.columns.length) return false // column lengths don't match
         if(left.count != right.count) return false // record count don't match
         return left.except(right).isEmpty && right.except(left).isEmpty
      }
      
      • 选项 2:按列生成行哈希
      def createHashColumn(df: DataFrame) : Column = {
         val colArr = df.columns
         md5(concat_ws("", (colArr.map(col(_))) : _*))
      }
      
      val z1SigDF = z1.select(col("index"), createHashColumn(z1).as("signature_z1"))
      val z2SigDF = z2.select(col("index"), createHashColumn(z2).as("signature_z2"))
      val joinDF = z1SigDF.join(z2SigDF, z1SigDF("index") === z2SigDF("index")).where($"signature_z1" =!= $"signature_z2").cache
      // should be 0
      joinDF.count
      
      • 选项 3:使用 GroupBy(对于具有重复行的 DataFrame)
      val z1Grouped = z1.groupBy(z1.columns.map(c => z1(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")
      val z2Grouped = z2.groupBy(z2.columns.map(c => z2(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")
      
      val inZ1NotInZ2 = z1Grouped.except(z2Grouped).toDF()
      val inZ2NotInZ1 = z2Grouped.except(z1Grouped).toDF()
      // both should be size 0
      inZ1NotInZ2.show
      inZ2NotInZ1.show
      
      • 选项 4,使用 exceptAll,它也适用于具有重复行的数据
      // Source Code: https://github.com/apache/spark/blob/50538600ec/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L2029
      val inZ1NotInZ2 = z1.exceptAll(z2).toDF()
      val inZ2NotInZ1 = z2.exceptAll(z1).toDF()
      // same here, // both should be size 0
      inZ1NotInZ2.show
      inZ2NotInZ1.show
      

      【讨论】:

      • Re:选项 2,concat 不适用于所有列类型,md5 可能在大数据上发生冲突。很好地添加了带有 exceptAll 的选项 4,它仅在 2.4.0 中添加。
      猜你喜欢
      • 1970-01-01
      • 2016-05-20
      • 2016-08-19
      • 2015-10-05
      • 1970-01-01
      相关资源
      最近更新 更多