【问题标题】:Conditional operations on numpy arraysnumpy 数组的条件操作
【发布时间】:2017-07-21 06:45:50
【问题描述】:

我是 NumPy 的新手,在 numpy 数组上运行一些条件语句时遇到了问题。假设我有 3 个如下所示的 numpy 数组:

一个:

[[0, 4, 4, 2],
 [1, 3, 0, 2],
 [3, 2, 4, 4]]

b:

[[6, 9, 8, 6],
 [7, 7, 9, 6],
 [8, 6, 5, 7]]

和,c:

[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]

我有一个a和b的条件语句,我想用b的值(如果a和b的条件都满足的话)来计算c的值:

c[(a > 3) & (b > 8)]+=b*2

我收到一条错误消息:

Traceback (most recent call last):
  File "<interactive input>", line 1, in <module>
ValueError: non-broadcastable output operand with shape (1,) doesn't match the broadcast shape (3,4)

知道我该如何做到这一点吗?

我希望 c 的输出如下所示:

[[0, 18, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]

【问题讨论】:

  • 谢谢!为每个人点赞,因为他们都在工作。我接受了@Psidom 对 np.where 的回答,因为它对我来说最有意义并且运行时间最快(我的实际脚本将运行这些条件数百万次)

标签: python arrays numpy conditional


【解决方案1】:

问题在于您屏蔽了接收部分,但没有不屏蔽发送者部分。结果:

c[(a > 3) & (b > 8)]+=b*2
# ^ 1x1 matrix        ^3x4 matrix

尺寸不一样。鉴于您想要执行逐元素添加(基于您的示例),您也可以简单地将切片添加到正确的部分:

c[(a &gt; 3) &amp; (b &gt; 8)]+=b<b>[(a &gt; 3) &amp; (b &gt; 8)]</b>*2

或者让它更有效率:

mask = (a > 3) & (b > 8)
c[mask] += b[mask]*2

【讨论】:

    【解决方案2】:

    你可以使用numpy.where:

    np.where((a > 3) & (b > 8), c + b*2, c)
    #array([[ 0, 18,  0,  0],
    #       [ 0,  0,  0,  0],
    #       [ 0,  0,  0,  0]])
    

    或算术:

    c + b*2 * ((a > 3) & (b > 8))
    #array([[ 0, 18,  0,  0],
    #       [ 0,  0,  0,  0],
    #       [ 0,  0,  0,  0]])
    

    【讨论】:

      【解决方案3】:

      对 numpy 表达式稍作改动就会得到想要的结果:

      c += ((a > 3) & (b > 8)) * b*2
      

      首先,我从((a &gt; 3) &amp; (b &gt; 8)) 创建一个带有布尔值的掩码矩阵,然后将该矩阵与b*2 相乘,从而生成一个3x4 矩阵,该矩阵可以轻松添加到c

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2015-02-12
        • 2020-08-10
        • 2016-06-06
        • 2021-04-21
        • 2021-05-12
        • 2018-11-30
        相关资源
        最近更新 更多