【问题标题】:How to test if all rows are equal in a numpy如何测试numpy中的所有行是否相等
【发布时间】:2014-11-27 14:43:36
【问题描述】:

在 numpy 中,是否有一种很好的惯用方法来测试二维数组中的所有行是否相等?

我可以做类似的事情

np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])

这似乎将 python 列表与 numpy 数组混合在一起,这很丑,而且可能也很慢。

有更好/更整洁的方法吗?

【问题讨论】:

  • 正如我对similar question 所说的那样,这确实需要一个适当的解决方案,它不会创建一个与原始数组一样大的临时数组(这里和那里的答案都是如此)。添加到 numpy 后,我会发布答案。

标签: python arrays numpy


【解决方案1】:

只需检查数组中唯一项的数量是否为 1:

>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> len(np.unique(arr)) == 1
True

灵感来自 unutbu 的 answer 的解决方案:

>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> np.all(np.all(arr == arr[0,:], axis = 1))
True

您的代码的一个问题是,在应用np.all() 之前,您首先要创建一个完整的列表。由于您的版本中没有发生短路,因此如果您将 Python 的 all() 与生成器表达式一起使用会更好:

时间比较:

>>> M = arr = np.array([[3]*100] + [[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 272 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 596 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 10.6 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100000 loops, best of 3: 11.3 µs per loop

>>> M = arr = np.array([[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 330 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 594 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 9.51 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100 loops, best of 3: 9.44 ms per loop

【讨论】:

  • 我觉得ajcr的回答更快!
  • @user2179021 它在我的系统上占用650 µs,所以仍然比我的第二个答案慢。
  • 同意 - 第二种方法在我的系统上也比我的更快。
  • 您是否认为有一种类似的快速方法来判断所有行是否不同?
  • 您的第一种方法实际上是检查数组的所有项是否相同,而不是所有行是否相同...试试np.array([[1,2,3],[1,2,3]])
【解决方案2】:

一种方法是检查数组arr 的每一行是否等于它的第一行arr[0]

(arr == arr[0]).all()

对整数值使用相等== 很好,但如果arr 包含浮点值,您可以使用np.isclose 来检查给定容差内的相等性:

np.isclose(a, a[0]).all()

如果您的数组包含NaN,并且您想避免棘手的NaN != NaN 问题,您可以将此方法与np.isnan 结合使用:

(np.isclose(a, a[0]) | np.isnan(a)).all()

【讨论】:

  • 我认为这是最快的方法。谢谢。
  • 检查是否相等,而不是等于 0 的差异,可能会快一点。
  • (np.diff(b, 1, 1) == 0).all()怎么样
  • @Jaime - 感谢您的建议,我已将其编辑为答案。
  • 现在有np.allclose
【解决方案3】:

值得一提的是above version 不适用于多维数组。

例如:对于一个三维正方形图像张量img [256, 256, 3],我们需要检查图像中是否有相同的RGB [256, 256]层。 在这种情况下,我们需要使用broadcasting

(img == img[:, :, 0, np.newaxis]).all()

因为简单的img[:, :, 0] 给了我们[256, 256],但是我们需要[256, 256, 1] 来通过层广播。

【讨论】:

    猜你喜欢
    • 2011-10-29
    • 2014-11-27
    • 2013-03-10
    • 2018-08-11
    • 2013-07-15
    • 2014-12-27
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多