【问题标题】:Is there a way to get Python all() function to work with multi-dimensional arrays?有没有办法让 Python all() 函数与多维数组一起工作?
【发布时间】:2019-10-18 03:08:47
【问题描述】:

我正在尝试为一个基类实现一个通用且灵活的 __eq__ 方法,该方法将与尽可能多的对象类型一起使用,包括可迭代对象和 numpy 数组。

这是我目前所拥有的:

class Environment:

    def __init__(self, state):
        self.state = state

    def __eq__(self, other):
        """Compare two environments based on their states.
        """
        if isinstance(other, self.__class__):
            try:
                return all(self.state == other.state)
            except TypeError:
                return self.state == other.state
        return False

这适用于大多数对象类型,包括一维数组:

s = 'abcdef'
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

s = [[1, 2, 3], [4, 5, 6]]
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

s = np.array(range(6))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

问题是,当self.state 是一个多维的numpy 数组时,它会返回一个ValueError。

s = np.array(range(6)).reshape((2, 3))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2

生产:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

显然,我可以检查isinstance(other, np.ndarray),然后检查(return self.state == other.state).all(),但我只是认为可能有一种更通用的方法可以用一个语句来处理任何类型的所有可迭代对象、集合和数组。

我也有点困惑为什么all() 不会像array.all() 那样遍历数组的所有元素。有没有办法触发np.nditer 并这样做?

【问题讨论】:

    标签: python arrays numpy equality


    【解决方案1】:
    Signature: all(iterable, /)
    Docstring:
    Return True if bool(x) is True for all values x in the iterable.
    

    对于一维数组:

    In [200]: x=np.ones(3)                                                               
    In [201]: x                                                                          
    Out[201]: array([1., 1., 1.])
    In [202]: y = x==x                                                                   
    In [203]: y          # 1d array of booleans                                                                      
    Out[203]: array([ True,  True,  True])
    In [204]: bool(y[0])                                                                 
    Out[204]: True
    In [205]: all(y)                                                                     
    Out[205]: True
    

    对于二维数组:

    In [206]: x=np.ones((2,3))                                                           
    In [207]: x                                                                          
    Out[207]: 
    array([[1., 1., 1.],
           [1., 1., 1.]])
    In [208]: y = x==x                                                                   
    In [209]: y                                                                          
    Out[209]: 
    array([[ True,  True,  True],
           [ True,  True,  True]])
    In [210]: y[0]                                                                       
    Out[210]: array([ True,  True,  True])
    In [211]: bool(y[0])                                                                 
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-211-d0ce0868392c> in <module>
    ----> 1 bool(y[0])
    
    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    

    但是对于不同的二维数组:

    In [212]: x=np.ones((3,1))                                                           
    In [213]: y = x==x                                                                   
    In [214]: y                                                                          
    Out[214]: 
    array([[ True],
           [ True],
           [ True]])
    In [215]: y[0]                                                                       
    Out[215]: array([ True])
    In [216]: bool(y[0])                                                                 
    Out[216]: True
    In [217]: all(y)                                                                     
    Out[217]: True
    

    numpy 数组的迭代沿第一个维度进行。 [i for i in x]

    只要在需要标量布尔值的上下文中使用多值布尔数组时,就会引发此歧义 ValueError。 ifor/and 表达式是常见的。

    In [223]: x=np.ones((2,3))                                                           
    In [224]: y = x==x                                                                   
    In [225]: np.all(y)                                                                  
    Out[225]: True
    

    np.all 与 Python all 的不同之处在于它“知道”尺寸。在这种情况下,它会使用ravel 将数组视为 1d:

    默认 (axis = None) 是对输入数组的所有维度执行逻辑与。

    【讨论】:

    • 谢谢。所以我猜没有单一的解决方案。
    • numpy 比较逐个元素工作的事实意味着比较与列表或元组根本不同。
    【解决方案2】:

    这不是我希望的简洁解决方案,而且可能效率低下,但我认为它适用于任何 n 维可迭代对象:

    def nd_true(nd_object):
        try:
            iterator = iter(nd_object)
        except TypeError:
            return nd_object
        else:
            return all([nd_true(x) for x in iterator])
    
    class Environment:
    
        def __init__(self, state):
            self.state = state
    
        def __eq__(self, other):
            """Compare two environments based on their states.
            """
            if isinstance(other, self.__class__):
                return nd_true(self.state == other.state)
            return False
    
    # Tests    
    s = 'abcdef'
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    
    s = [[1, 2, 3], [4, 5, 6]]
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    
    s = np.array(range(6))
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    
    s = np.array(range(6)).reshape((2, 3))
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    
    s = np.array(range(27)).reshape((3, 3, 3))
    e1 = Environment(s)
    e2 = Environment(s)
    
    e1 == e2  # True
    

    【讨论】:

      猜你喜欢
      • 2020-03-09
      • 1970-01-01
      • 2020-12-07
      • 2013-03-27
      • 2010-09-12
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-07-29
      相关资源
      最近更新 更多