【发布时间】:2023-10-01 13:46:01
【问题描述】:
对于我的单元测试,我想检查两个数组是否相同。简化示例:
a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
if np.all(a==b):
print 'arrays are equal'
这不起作用,因为nan != nan。
最好的方法是什么?
【问题讨论】:
标签: python numpy nan equality-operator
对于我的单元测试,我想检查两个数组是否相同。简化示例:
a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
if np.all(a==b):
print 'arrays are equal'
这不起作用,因为nan != nan。
最好的方法是什么?
【问题讨论】:
标签: python numpy nan equality-operator
您也可以将numpy.testing.assert_equal 或numpy.testing.assert_array_equal 与try/except 一起使用:
In : import numpy as np
In : def nan_equal(a,b):
...: try:
...: np.testing.assert_equal(a,b)
...: except AssertionError:
...: return False
...: return True
In : a=np.array([1, 2, np.NaN])
In : b=np.array([1, 2, np.NaN])
In : nan_equal(a,b)
Out: True
In : a=np.array([1, 2, np.NaN])
In : b=np.array([3, 2, np.NaN])
In : nan_equal(a,b)
Out: False
编辑
由于您将其用于单元测试,因此裸 assert(而不是包装它以获取 True/False)可能更自然。
【讨论】:
np.testing.assert_equal(a,b),如果它引发了异常,则测试失败(没有错误),我什至得到了一个带有差异和不匹配的漂亮打印。谢谢。
numpy.testing.assert_* 不遵循 python assert 的相同语义。在纯 Python AssertionError 中,如果 __debug__ is True 引发异常,即如果脚本运行未优化(无 -O 标志),请参阅 docs。出于这个原因,我强烈反对使用 AssertionErrors 进行流量控制。当然,由于我们在测试套件中,最好的解决方案是不理会 numpy.testing.assert。
numpy.testing.assert_equal() 的文档没有明确指出它认为 NaN 等于 NaN(而 numpy.testing.assert_array_equal() 确实如此):它在其他地方记录了吗?
nan = nan?我得到一个 AssertionError: Arrays are not equal 即使数组是相同的,包括 dtype。
对于 1.19 之前的 numpy 版本,在不专门涉及单元测试的情况下,这可能是最好的方法:
>>> ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
True
但是,现代版本为 array_equal 函数提供了一个新的关键字参数 equal_nan,这完全符合要求。
这是由flyingdutchman首先指出的;详情见下方his answer。
【讨论】:
%timeit 进行测试产生 23.7 µs 与 1.01 ms。
最简单的方法是使用numpy.allclose() 方法,它允许指定具有nan 值时的行为。那么您的示例将如下所示:
a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])
if np.allclose(a, b, equal_nan=True):
print('arrays are equal')
然后arrays are equal将被打印出来。
你可以找到here相关文档
【讨论】:
TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
bool 而不是引发 AssertionError。我需要这个来实现具有数组属性的类的__eq__(...)。
rtol=0, atol=0 以避免它认为接近值相等的问题(如@senderle 所述)。所以:np.allclose(a, b, equal_nan=True, rtol=0, atol=0).
您可以使用 numpy 屏蔽数组,屏蔽 NaN 值,然后使用 numpy.ma.all 或 numpy.ma.allclose:
http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.all.html
http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.allclose.html
例如:
a=np.array([1, 2, np.NaN])
b=np.array([1, 2, np.NaN])
np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) #True
【讨论】:
np.ma.masked_where(np.isnan(a), a) 否则您无法比较无限值。
a=np.array([1, 2, np.NaN])和b=np.array([1, np.NaN, 2])进行了测试,它们显然不相等,np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b))仍然返回True,所以如果你使用这种方法,请注意这一点。
只是为了完成@Luis Albert Centeno’s answer,你可能宁愿使用:
np.allclose(a, b, rtol=0, atol=0, equal_nan=True)
rtol 和 atol 控制相等测试的容差。简而言之,allclose() 返回:
all(abs(a - b) <= atol + rtol * abs(b))
默认情况下,它们未设置为 0,因此如果您的数字接近但不完全相等,该函数可能会返回 True。
PS: "我想检查两个数组是否相同 " >> 实际上,您正在寻找平等而不是身份。它们在 Python 中是不一样的,我认为最好让每个人都了解它们的区别,以便共享相同的词典。 (https://www.blog.pythonlibrary.org/2017/02/28/python-101-equality-vs-identity/)
您将通过关键字is 测试身份:
a is b
【讨论】:
当我使用上述答案时:
((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all()
在评估字符串列表时它给了我一些错误。
这是更通用的类型:
def EQUAL(a,b):
return ((a == b) | ((a != a) & (b != b)))
【讨论】:
numpy 函数array_equal 完全符合问题的要求。当被问到时,它很可能不存在。 该示例如下所示:
a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
assert np.array_equal(a, b, equal_nan=True)
但请注意,如果元素的 dtype 为 object,这将不起作用。不确定这是不是bug。
【讨论】:
截至v1.9,numpy 的array_equal 函数支持equal_nan 参数:
assert np.array_equal(a, b, equal_nan=True)
【讨论】:
如果您为单元测试之类的事情这样做, 这样您就不太关心所有类型的性能和“正确”行为,您可以使用它获得一些有效的东西包含所有类型的数组,而不仅仅是数字:
a = np.array(['a', 'b', None])
b = np.array(['a', 'b', None])
assert list(a) == list(b)
将ndarrays 转换为lists 有时对于在某些测试中获得您想要的行为很有用。 (但不要在生产代码或更大的数组中使用它!)
【讨论】:
对我来说这很好用:
a = numpy.array(float('nan'), 1, 2)
b = numpy.array(2, float('nan'), 2)
numpy.equal(a, b, where =
numpy.logical_not(numpy.logical_or(
numpy.isnan(a),
numpy.isnan(b)
))
).all()
PS。有 nan 时忽略比较
【讨论】: