【发布时间】:2019-10-18 03:08:47
【问题描述】:
我正在尝试为一个基类实现一个通用且灵活的 __eq__ 方法,该方法将与尽可能多的对象类型一起使用,包括可迭代对象和 numpy 数组。
这是我目前所拥有的:
class Environment:
def __init__(self, state):
self.state = state
def __eq__(self, other):
"""Compare two environments based on their states.
"""
if isinstance(other, self.__class__):
try:
return all(self.state == other.state)
except TypeError:
return self.state == other.state
return False
这适用于大多数对象类型,包括一维数组:
s = 'abcdef'
e1 = Environment(s)
e2 = Environment(s)
e1 == e2 # True
s = [[1, 2, 3], [4, 5, 6]]
e1 = Environment(s)
e2 = Environment(s)
e1 == e2 # True
s = np.array(range(6))
e1 = Environment(s)
e2 = Environment(s)
e1 == e2 # True
问题是,当self.state 是一个多维的numpy 数组时,它会返回一个ValueError。
s = np.array(range(6)).reshape((2, 3))
e1 = Environment(s)
e2 = Environment(s)
e1 == e2
生产:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
显然,我可以检查isinstance(other, np.ndarray),然后检查(return self.state == other.state).all(),但我只是认为可能有一种更通用的方法可以用一个语句来处理任何类型的所有可迭代对象、集合和数组。
我也有点困惑为什么all() 不会像array.all() 那样遍历数组的所有元素。有没有办法触发np.nditer 并这样做?
【问题讨论】:
标签: python arrays numpy equality