【问题标题】:Comparing numpy arrays containing NaN比较包含 NaN 的 numpy 数组
【发布时间】:2023-10-01 13:46:01
【问题描述】:

对于我的单元测试,我想检查两个数组是否相同。简化示例:

a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])

if np.all(a==b):
    print 'arrays are equal'

这不起作用,因为nan != nan。 最好的方法是什么?

【问题讨论】:

    标签: python numpy nan equality-operator


    【解决方案1】:

    您也可以将numpy.testing.assert_equalnumpy.testing.assert_array_equaltry/except 一起使用:

    In : import numpy as np
    
    In : def nan_equal(a,b):
    ...:     try:
    ...:         np.testing.assert_equal(a,b)
    ...:     except AssertionError:
    ...:         return False
    ...:     return True
    
    In : a=np.array([1, 2, np.NaN])
    
    In : b=np.array([1, 2, np.NaN])
    
    In : nan_equal(a,b)
    Out: True
    
    In : a=np.array([1, 2, np.NaN])
    
    In : b=np.array([3, 2, np.NaN])
    
    In : nan_equal(a,b)
    Out: False
    

    编辑

    由于您将其用于单元测试,因此裸 assert(而不是包装它以获取 True/False)可能更自然。

    【讨论】:

    • 非常好,这是最优雅的内置解决方案。我刚刚在我的单元测试中添加了np.testing.assert_equal(a,b),如果它引发了异常,则测试失败(没有错误),我什至得到了一个带有差异和不匹配的漂亮打印。谢谢。
    • 请注意,此解决方案有效,因为 numpy.testing.assert_* 不遵循 python assert 的相同语义。在纯 Python AssertionError 中,如果 __debug__ is True 引发异常,即如果脚本运行未优化(无 -O 标志),请参阅 docs。出于这个原因,我强烈反对使用 AssertionErrors 进行流量控制。当然,由于我们在测试套件中,最好的解决方案是不理会 numpy.testing.assert。
    • numpy.testing.assert_equal() 的文档没有明确指出它认为 NaN 等于 NaN(而 numpy.testing.assert_array_equal() 确实如此):它在其他地方记录了吗?
    • @EricOLebigot numpy.testing.assert_equal() 是否依赖考虑nan = nan?我得到一个 AssertionError: Arrays are not equal 即使数组是相同的,包括 dtype。
    • current 官方文档和上面的例子都表明它确实认为 NaN == NaN。我认为最好的办法是让您提出一个包含详细信息的新 * 问题。
    【解决方案2】:

    对于 1.19 之前的 numpy 版本,在不专门涉及单元测试的情况下,这可能是最好的方法:

    >>> ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
    True
    

    但是,现代版本为 array_equal 函数提供了一个新的关键字参数 equal_nan,这完全符合要求。

    这是由flyingdutchman首先指出的;详情见下方his answer

    【讨论】:

    • +1 这个解决方案似乎比我使用掩码数组发布的解决方案快一点,但如果您要创建掩码以用于代码的其他部分,则创建掩码的开销会成为影响 ma 策略整体效率的因素。
    • 谢谢。您的解决方案确实有效,但我更喜欢 Avaris 建议的 numpy 中的内置测试
    • 我真的很喜欢它的简单性。此外,它似乎比@Avaris 解决方案更快。将其转换为 lambda 函数,使用 Ipython 的 %timeit 进行测试产生 23.7 µs 与 1.01 ms。
    • @NovicePhysicist,有趣的时机!我想知道它是否与使用异常处理有关。您是否测试了阳性与阴性结果?根据是否抛出异常,速度可能会有很大差异。
    • 不,只是做了一个简单的测试,一些与我手头的问题相关的广播(比较了二维数组和一维向量——所以我猜这是逐行比较)。但我想一个人可以很容易地在 Ipython 笔记本上做很多测试。另外,我为您的解决方案使用了 lambda 函数,但我认为如果我使用常规函数(通常似乎是这种情况),它应该会快一点。
    【解决方案3】:

    最简单的方法是使用numpy.allclose() 方法,它允许指定具有nan 值时的行为。那么您的示例将如下所示:

    a = np.array([1, 2, np.nan])
    b = np.array([1, 2, np.nan])
    
    if np.allclose(a, b, equal_nan=True):
        print('arrays are equal')
    

    然后arrays are equal将被打印出来。

    你可以找到here相关文档

    【讨论】:

    • +1 因为您的解决方案不会重新发明*。但是,这仅适用于类似数字的项目。否则,你会得到讨厌的TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
    • 这在许多情况下都是一个很好的答案!值得一提的是,即使数组不严格相等,这也会返回 true。不过很多时候这并不重要。
    • +1,因为这会返回 bool 而不是引发 AssertionError。我需要这个来实现具有数组属性的类的__eq__(...)
    • 就像指向稍后答案的指针:*.com/a/58709110/1207489。添加rtol=0, atol=0 以避免它认为接近值相等的问题(如@senderle 所述)。所以:np.allclose(a, b, equal_nan=True, rtol=0, atol=0).
    【解决方案4】:

    您可以使用 numpy 屏蔽数组,屏蔽 NaN 值,然后使用 numpy.ma.allnumpy.ma.allclose

    http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.all.html

    http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.allclose.html

    例如:

    a=np.array([1, 2, np.NaN])
    b=np.array([1, 2, np.NaN])
    np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) #True
    

    【讨论】:

    • 感谢您让我意识到掩码数组的使用。不过,我更喜欢 Avaris 的解决方案。
    • 您应该使用np.ma.masked_where(np.isnan(a), a) 否则您无法比较无限值。
    • 我用a=np.array([1, 2, np.NaN])b=np.array([1, np.NaN, 2])进行了测试,它们显然不相等,np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b))仍然返回True,所以如果你使用这种方法,请注意这一点。
    • 此方法只测试没有 NaN 值的两个数组是否相同,但不测试 NaN 是否出现在相同的地方......使用起来可能很危险。
    【解决方案5】:

    只是为了完成@Luis Albert Centeno’s answer,你可能宁愿使用:

    np.allclose(a, b, rtol=0, atol=0, equal_nan=True)
    

    rtolatol 控制相等测试的容差。简而言之,allclose() 返回:

    all(abs(a - b) <= atol + rtol * abs(b))
    

    默认情况下,它们未设置为 0,因此如果您的数字接近但不完全相等,该函数可能会返回 True


    PS: "我想检查两个数组是否相同 " >> 实际上,您正在寻找平等而不是身份。它们在 Python 中是不一样的,我认为最好让每个人都了解它们的区别,以便共享相同的词典。 (https://www.blog.pythonlibrary.org/2017/02/28/python-101-equality-vs-identity/)

    您将通过关键字is 测试身份:

    a is b
    

    【讨论】:

      【解决方案6】:

      当我使用上述答案时:

       ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
      

      在评估字符串列表时它给了我一些错误。

      这是更通用的类型:

      def EQUAL(a,b):
          return ((a == b) | ((a != a) & (b != b)))
      

      【讨论】:

        【解决方案7】:

        numpy 函数array_equal 完全符合问题的要求。当被问到时,它很可能不存在。 该示例如下所示:

        a = np.array([1, 2, np.NaN])
        b = np.array([1, 2, np.NaN])
        assert np.array_equal(a, b, equal_nan=True)
        

        但请注意,如果元素的 dtype 为 object,这将不起作用。不确定这是不是bug

        【讨论】:

          【解决方案8】:

          截至v1.9,numpy 的array_equal 函数支持equal_nan 参数:

          assert np.array_equal(a, b, equal_nan=True)
          

          【讨论】:

            【解决方案9】:

            如果您为单元测试之类的事情这样做, 这样您就不太关心所有类型的性能和“正确”行为,您可以使用它获得一些有效的东西包含所有类型的数组,而不仅仅是数字

            a = np.array(['a', 'b', None])
            b = np.array(['a', 'b', None])
            assert list(a) == list(b)
            

            ndarrays 转换为lists 有时对于在某些测试中获得您想要的行为很有用。 (但不要在生产代码或更大的数组中使用它!)

            【讨论】:

              【解决方案10】:

              对我来说这很好用:

              a = numpy.array(float('nan'), 1, 2)
              b = numpy.array(2, float('nan'), 2)
              numpy.equal(a, b, where = 
                  numpy.logical_not(numpy.logical_or(
                      numpy.isnan(a), 
                      numpy.isnan(b)
                  ))
              ).all()
              

              PS。有 nan 时忽略比较

              【讨论】: