【问题标题】:Check if two numpy arrays are identical检查两个numpy数组是否相同
【发布时间】:2023-04-07 22:23:02
【问题描述】:

假设我有一堆数组,包括xy,我想检查它们是否相等。一般来说,我可以只使用np.all(x == y)(除非我现在忽略了一些愚蠢的角落案例)。

不过,这会评估(x == y)整个 数组,这通常是不需要的。我的数组真的很大,而且我有很多,而且两个数组相等的概率很小,所以很可能,我真的只需要在all函数之前评估(x == y)的一小部分可能返回 False,所以这对我来说不是最佳解决方案。

我尝试使用内置的all 函数,结合itertools.izipall(val1==val2 for val1,val2 in itertools.izip(x, y))

但是,在两个数组相等的情况下,这似乎要慢得多,总的来说,它仍然不值得在np.all 上使用。我推测是因为内置 all 的通用性。而np.all 不适用于生成器。

有没有办法以更快的方式做我想做的事?

我知道这个问题类似于之前提出的问题(例如 Comparing two numpy arrays for equality, element-wise),但它们具体不包括提前终止的情况。

【问题讨论】:

  • @Thomas:那个函数只是在内部调用np.all,所以它有点没用。 (我确实希望有一个专门为此目的的函数来进行短路,但可惜它没有。)
  • 嗯,真可惜。我猜,一个 numpy-internal 函数将是你唯一的机会,因为任何在 numpy 之外的循环几乎都会变慢。您是否考虑过直接联系开发人员?

标签: python numpy


【解决方案1】:

在 numpy 原生实现之前,您可以编写自己的函数并使用 numba 对其进行 jit 编译:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def arrays_equal(a, b):
    if a.shape != b.shape:
        return False
    for ai, bi in zip(a.flat, b.flat):
        if ai != bi:
            return False
    return True


a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)


%timeit np.all(a==b)  # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a)  # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b)  # 100000 loops, best of 3: 691 ns per loop

最坏情况下的性能(数组相等)相当于np.all,并且在提前停止编译函数的情况下有可能大大优于np.all

【讨论】:

  • 我喜欢它,但是对于我的测试数组,如果它们相等,它仍然需要比np.all(arr1==arr2) 长约 1.6 倍。 (作为参考,arr1 = np.ones((1000000,), dtype=bool), 'arr2 = np.ones((1000000,), dtype=bool)', 'arr2[100000] = False`)。 (确保将 timeit 上的数字降低到 1000 左右。)
  • @acdr 当我使用您的阵列时,np.all 需要 1.8 毫秒,arrays_equal 需要 183 微秒。如果我将arr1 与自身进行比较,两者都需要大约 1.8 毫秒。也许这种差异是由我们的系统造成的?我有 Python 3.5.2、numpy 1.12.1 和 numba 0.27.0。
  • 可能是。一般来说,我运行的东西比你老得多:Python 2.7.10.2、numpy 1.9.1、numba 0.20.0
  • Np.all 没有分支指令。在数组相同的情况下,您希望没有分支的函数比有分支的函数更快。这可能就是差异的来源。您应该查看您的用例并确定更有可能发生的情况。这仍然是 python,而不是汇编,所以微优化可能并不总是有你想要的效果。
  • @JannPoppinga 不幸的是,不;至少不在这个函数中。 array_equal 只是calls np.all(a==b) internally
【解决方案2】:

numpy page on github 显然正在讨论向数组比较添加短路逻辑,因此可能会在未来版本的 numpy 中提供。

【讨论】:

    【解决方案3】:

    嗯,我知道这是一个糟糕的答案,但似乎没有简单的方法。 Numpy Creators 应该修复它。我建议:

    def compare(a, b):
        if len(a) > 0 and not np.array_equal(a[0], b[0]):
            return False
        if len(a) > 15 and not np.array_equal(a[:15], b[:15]):
            return False
        if len(a) > 200 and not np.array_equal(a[:200], b[:200]):
            return False
        return np.array_equal(a, b)
    

    :)

    【讨论】:

    • 因为没有人说过,用numpy做不到,我认为问题仍然存在
    • This answer 已被接受并使用 numpy
    • 它使用 numba。如果你对某人诚实地告诉你没有更好的方法来做某事不满意,你可以举报它,但我的回答至少包含创造性的解决方案。
    【解决方案4】:

    嗯,这不是一个真正的答案,因为我没有检查它是否断路,但是:

    assert_array_equal.

    来自文档:

    如果两个 array_like 对象不相等,则引发 AssertionError。

    TryExcept 如果不在性能敏感的代码路径上。

    或者按照底层源码,或许效率更高。

    【讨论】:

    • 感谢您的建议。不幸的是,底层代码看起来只是 x == y 的包装,为一些极端情况(如 NaNs 和 Infs)添加了一些额外的步骤。
    【解决方案5】:

    您可以迭代数组的所有元素并检查它们是否相等。 如果数组很可能不相等,它的返回速度将比 .all 函数快得多。 像这样的:

    import numpy as np
    
    a = np.array([1, 2, 3])
    b = np.array([1, 3, 4])
    
    areEqual = True
    
    for x in range(0, a.size-1):
            if a[x] != b[x]:
                    areEqual = False
                    break
            else:
                   print "a[x] is equal to b[x]\n"
    
    if areEqual:
            print "The tables are equal\n"
    else:
            print "The tables are not equal\n"
    

    【讨论】:

    • 这实际上是all(val1==val2 for val1,val2 in itertools.izip(x, y)) 所做的:它循环通过xy,返回成对的val1val2,检查它们是否相同,然后通过结果到all,一旦找到不相等的对,它将立即返回False
    • 哦,我明白了,我以为它会遍历数组的所有元素。
    • 幸运的是,内置的all 确实会进行熔断,这与np.all 不同。 :)
    【解决方案6】:

    可能了解底层数据结构的人可以对此进行优化或解释它是否可靠/安全/良好的做法,但它似乎有效。

    np.all(a==b)
    Out[]: True
    
    memoryview(a.data)==memoryview(b.data)
    Out[]: True
    
    %timeit np.all(a==b)
    The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached.
    100000 loops, best of 3: 6.2 µs per loop
    
    %timeit memoryview(a.data)==memoryview(b.data)
    The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached.
    100000 loops, best of 3: 1.85 µs per loop
    

    如果我理解正确,ndarray.data 创建一个指向数据缓冲区的指针,memoryview 创建一个可以从缓冲区短路的本机 python 类型。

    我认为。

    编辑:进一步的测试表明它可能没有如图所示的时间改进那么大。以前a=b=np.eye(5)

    a=np.random.randint(0,10,(100,100))
    
    b=a.copy()
    
    %timeit np.all(a==b)
    The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached.
    10000 loops, best of 3: 17.7 µs per loop
    
    %timeit memoryview(a.data)==memoryview(b.data)
    10000 loops, best of 3: 30.1 µs per loop
    
    np.all(a==b)
    Out[]: True
    
    memoryview(a.data)==memoryview(b.data)
    Out[]: True
    

    【讨论】:

    • 这不只是测试两个数组是否实际上是同一个对象的不同名称,而不是具有相同值的两个不同对象吗?
    • 据我所知没有。如上所述使用.copy() 进行测试,然后以相同的方式依次操作上面的两个随机数组。
    • 对我不起作用,使用当前 Anaconda 版本的 numpy。也许它只是不喜欢 NaN。
    • @matanster 不确定您尝试了什么,但在标准用法中 NaN != NaN
    【解决方案7】:

    正如 Thomas Kühn 在对您的帖子的评论中所写,array_equal 是一个可以解决问题的函数。在Numpy's API reference 中有描述。

    【讨论】:

      猜你喜欢
      • 2019-09-04
      • 2012-05-31
      • 2021-11-15
      • 2020-02-12
      • 2014-02-03
      • 2017-09-09
      • 2014-12-06
      相关资源
      最近更新 更多