【问题标题】:Numpy: Can you use broadcasting to replace values by row?Numpy:你可以使用广播来逐行替换值吗?
【发布时间】:2018-07-21 23:28:40
【问题描述】:

我有一个 M x N 矩阵 X 和一个 1 x N 矩阵 Y。我想做的是根据 Y 的列将 X 中的任何 0 条目替换为 Y 中的适当值。

如果

X = np.array([[0, 1, 2], [3, 0, 5]])

Y = np.array([10, 20, 30])

所需的最终结果将是 [[10, 1, 2], [3, 20, 5]]。

这可以通过生成一个每行为 Y 的 M x N 矩阵然后使用过滤器数组来直接完成:

Y = np.ones((X.shape[0], 1)) * Y.reshape(1, -1)
X[X==0] = Y[X==0]

但是这可以使用 numpy 的广播功能来完成吗?

【问题讨论】:

    标签: python numpy array-broadcasting


    【解决方案1】:

    当然。不要物理重复Y,而是使用numpy.broadcast_to 创建Y 的广播视图,其形状为X

    expanded = numpy.broadcast_to(Y, X.shape)
    
    mask = X==0
    x[mask] = expanded[mask]
    

    【讨论】:

      【解决方案2】:

      扩展 X 使其更通用:

      In [306]: X = np.array([[0, 1, 2], [3, 0, 5],[0,1,0]])
      

      where标识0;第二个数组标识列

      In [307]: idx = np.where(X==0)
      In [308]: idx
      Out[308]: (array([0, 1, 2, 2]), array([0, 1, 0, 2]))
      
      
      In [309]: Z = X.copy()
      In [310]: Z[idx]
      Out[310]: array([0, 0, 0, 0])       # flat list of where to put the values
      In [311]: Y[idx[1]]
      Out[311]: array([10, 20, 10, 30])   # matching list of values by column
      
      In [312]: Z[idx] = Y[idx[1]]
      In [313]: Z
      Out[313]: 
      array([[10,  1,  2],
             [ 3, 20,  5],
             [10,  1, 30]])
      

      不做广播,但相当干净numpy


      broadcast_to 方法相比的次数

      In [314]: %%timeit 
           ...: idx = np.where(X==0)
           ...: Z[idx] = Y[idx[1]]
           ...: 
      9.28 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
      In [315]: %%timeit
           ...: exp = np.broadcast_to(Y,X.shape)
           ...: mask=X==0
           ...: Z[mask] = exp[mask]
           ...: 
      19.5 µs ± 513 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
      

      虽然样本量很小,但速度更快。

      制作expanded Y 的另一种方法是使用repeat

      In [319]: %%timeit
           ...: exp = np.repeat(Y[None,:],3,0)
           ...: mask=X==0
           ...: Z[mask] = exp[mask]
           ...: 
      10.8 µs ± 55.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
      

      谁的时间接近我的where。原来broadcast_to比较慢:

      In [321]: %%timeit
           ...: exp = np.broadcast_to(Y,X.shape)
           ...: 
      10.5 µs ± 52.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
      In [322]: %%timeit
           ...: exp = np.repeat(Y[None,:],3,0)
           ...: 
      3.76 µs ± 11.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
      

      我们必须做更多的测试,看看这是否只是由于设置成本,或者相对时间是否仍然适用于更大的阵列。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2016-09-08
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多