【Question title】: How to vectorize these for loops in Python?
【Posted】: 2019-05-30 08:54:02
【Question】:

The code is as follows:

import numpy as np


data = np.random.randint(0, 10, 12).reshape(3, 4)
print(data)

h, w = data.shape[:2]
dataMask = np.zeros((h, w, 10), dtype=int)

r = 2

for i in range(h):
    for j in range(w):
        for ir in range(i - r, i + r):
            for jr in range(j - r, j + r):
                if ir >= 0 and ir < h and jr >= 0 and jr < w:
                    dataMask[i, j, data[ir, jr]] += 1

print(dataMask)

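I have a numpy array "data" of shape (h, w) whose elements are integers in [0, 10).
I create an array dataMask of shape (h, w, 10), where dataMask[i, j, k] is the number of points with value k inside a region of data. The region is a square centered at (i, j), with r = 2.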

How can I vectorize the for loops in this code? Thanks!
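For reference, here is a sketch of the same computation written with slicing plus np.bincount, assuming the values are integers in [0, n_values). It makes the window explicit: range(i - r, i + r) covers rows i-r through i+r-1 (a 2r × 2r window, so not exactly centered on (i, j)), clipped at the array border. The function name window_counts_reference and the n_values parameter are hypothetical.

```python
import numpy as np

def window_counts_reference(data, r, n_values=10):
    """Count each value in rows i-r..i+r-1 and cols j-r..j+r-1 (clipped).

    Hypothetical helper; assumes data holds ints in [0, n_values).
    """
    h, w = data.shape
    out = np.zeros((h, w, n_values), dtype=int)
    for i in range(h):
        for j in range(w):
            # Same half-open window as range(i - r, i + r) in the question.
            # max(..., 0) stops negative starts from wrapping around;
            # ends past the border are simply truncated by slicing.
            win = data[max(i - r, 0):i + r, max(j - r, 0):j + r]
            out[i, j] = np.bincount(win.ravel(), minlength=n_values)
    return out

counts = window_counts_reference(np.array([[1, 2], [2, 3]]), 1)
print(counts[1, 1])  # one 1, two 2s, one 3
```

This still loops over pixels, but it is a convenient ground truth when checking a vectorized version.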


    Tags: python numpy vectorization


    【Answer 1】:

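    Here is one approach using cumsum: one-hot encode the data, take running sums along each axis, and subtract shifted running sums to get the window totals.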

    import numpy as np
    
    
    data = np.random.randint(0, 10, 1200).reshape(30, 40)
    print(data)
    
    h, w = data.shape[:2]
    dataMask = np.zeros((h, w, 10), dtype=int)
    
    r = 20
    
    from time import time
    T = []
    
    T.append(time())
    
    for i in range(h):
        for j in range(w):
            for ir in range(i - r, i + r):
                for jr in range(j - r, j + r):
                    if ir >= 0 and ir < h and jr >= 0 and jr < w:
                        dataMask[i, j, data[ir, jr]] += 1
    
    T.append(time())
    
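    m1 = np.zeros((h, w, 10), dtype=int)
    # One-hot encode: m1[i, j, data[i, j]] = 1
    np.put_along_axis(m1, data[..., None], 1, 2)
    m2 = np.empty_like(m1)
    # Window sums along axis 1 (columns j-r .. j+r-1) from running sums
    m1 = m1.cumsum(1)
    m2[:, :w-r+1] = m1[:, r-1:]
    m2[:, w-r+1:] = m1[:, -1, None]
    m2[:, r+1:] -= m1[:, :-r-1]
    # Same along axis 0 (rows i-r .. i+r-1); m1 is reused as the output
    m2 = m2.cumsum(0)
    m1[:h-r+1] = m2[r-1:]
    m1[h-r+1:] = m2[-1, None]
    m1[r+1:] -= m2[:-r-1]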
    
    T.append(time())
    
    
    assert (dataMask==m1).all()
    
    print(np.diff(T))
    

    Running the example with h, w, r = 30, 40, 20:

    # time [seconds] used by
    # OP            cumsum
    [9.23162699e-01 3.41892242e-04]
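The cumsum answer is effectively building a summed-area table (integral image) of the one-hot encoded data. Below is a sketch of the same idea that pads the table with a leading row and column of zeros, so border windows need no special-case slicing and any r >= 0 works. It assumes values are integers in [0, n_values); window_counts_sat is a hypothetical name.

```python
import numpy as np

def window_counts_sat(data, r, n_values=10):
    """Window value counts via a zero-padded summed-area table (hypothetical helper)."""
    h, w = data.shape
    # One-hot encode: onehot[i, j, k] == 1 iff data[i, j] == k.
    onehot = (data[..., None] == np.arange(n_values)).astype(np.int64)
    # S[a, b, k] = number of k's in data[:a, :b]; row 0 and column 0 stay zero.
    S = np.zeros((h + 1, w + 1, n_values), dtype=np.int64)
    S[1:, 1:] = onehot.cumsum(axis=0).cumsum(axis=1)
    # Half-open window bounds [i - r, i + r), clipped to the array.
    lo_i = np.clip(np.arange(h) - r, 0, h)[:, None]
    hi_i = np.clip(np.arange(h) + r, 0, h)[:, None]
    lo_j = np.clip(np.arange(w) - r, 0, w)[None, :]
    hi_j = np.clip(np.arange(w) + r, 0, w)[None, :]
    # Inclusion-exclusion on the four corners; each term has shape (h, w, n_values).
    return S[hi_i, hi_j] - S[lo_i, hi_j] - S[hi_i, lo_j] + S[lo_i, lo_j]

data = np.random.default_rng(0).integers(0, 10, (30, 40))
print(window_counts_sat(data, 20).shape)
```

The cost is O(h·w·n_values) regardless of r, which is why this family of approaches scales so much better than the original loops as the window grows.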
    


      【Answer 2】:

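      Here is a "partially vectorized" solution that only loops over the (2r)² window offsets; each iteration handles every pixel at once.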

      import numpy as np
      from itertools import product
      
      # Input data
      np.random.seed(0)
      data = np.random.randint(0, 10, 12).reshape(3, 4)
      h, w = data.shape[:2]
      dataMask = np.zeros((h, w, 10), dtype=int)
      r = 2
      
      # Original solution
      for i in range(h):
          for j in range(w):
              for ir in range(i - r, i + r):
                  for jr in range(j - r, j + r):
                      if ir >= 0 and ir < h and jr >= 0 and jr < w:
                          dataMask[i, j, data[ir, jr]] += 1
      
      # Partially vectorized solution
      idx_i, idx_j = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
      idx_i = idx_i.ravel()
      idx_j = idx_j.ravel()
      idx_k = data.ravel()
      dataMask2 = np.zeros((h, w, 10), dtype=int)
      for i, j in product(range(-r + 1, r + 1), repeat=2):
          ii = idx_i + i
          jj = idx_j + j
          m = (ii >= 0) & (ii < h) & (jj >= 0) & (jj < w)
          ii = ii[m]
          jj = jj[m]
          kk = idx_k[m]
          np.add.at(dataMask2, (ii, jj, kk), 1)
      
      print(np.all(dataMask == dataMask2))
      # True
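The partially vectorized solution above still scatters through np.add.at on every offset. A sketch of an alternative keeps the same loop over window offsets but adds whole shifted slices of a one-hot array, so no scatter is needed at all. It assumes values are integers in [0, n_values); window_counts_shifted is a hypothetical name.

```python
import numpy as np
from itertools import product

def window_counts_shifted(data, r, n_values=10):
    """Loop over window offsets only, adding shifted one-hot slices (hypothetical helper)."""
    h, w = data.shape
    onehot = (data[..., None] == np.arange(n_values)).astype(np.int64)
    out = np.zeros((h, w, n_values), dtype=np.int64)
    # Offsets di, dj in [-r, r) match ir = i + di for ir in range(i - r, i + r).
    for di, dj in product(range(-r, r), repeat=2):
        # Target rows i with 0 <= i < h and 0 <= i + di < h (same for columns).
        i0, i1 = max(0, -di), min(h, h - di)
        j0, j1 = max(0, -dj), min(w, w - dj)
        if i0 < i1 and j0 < j1:  # skip offsets that fall entirely outside
            out[i0:i1, j0:j1] += onehot[i0 + di:i1 + di, j0 + dj:j1 + dj]
    return out

data = np.random.default_rng(0).integers(0, 10, (3, 4))
print(window_counts_shifted(data, 2).sum(axis=2))  # window sizes after clipping
```

The guard matters: when an offset exceeds the array size, the computed bounds would otherwise turn into negative slice ends that silently wrap around.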
      

      You can make the partially vectorized solution fully vectorized by tiling the data further (at the cost of more memory):

      import numpy as np
      
      # Fully vectorized (continues the snippet above: reuses data, h, w, r and dataMask)
      idx_i, idx_j = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
      w_i, w_j = np.meshgrid(np.arange(-r + 1, r + 1), np.arange(-r + 1, r + 1), indexing='ij')
      ii = (idx_i[:, :, np.newaxis, np.newaxis] + w_i).ravel()
      jj = (idx_j[:, :, np.newaxis, np.newaxis] + w_j).ravel()
      kk = np.tile(data[:, :, np.newaxis, np.newaxis], (1, 1, 2 * r, 2 * r)).ravel()
      m = (ii >= 0) & (ii < h) & (jj >= 0) & (jj < w)
      ii = ii[m]
      jj = jj[m]
      kk = kk[m]
      dataMask3 = np.zeros((h, w, 10), dtype=int)
      np.add.at(dataMask3, (ii, jj, kk), 1)
      print(np.all(dataMask == dataMask3))
      # True
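The fully vectorized answer scatters with np.add.at. A sketch of the same scatter formulation builds the index arrays by broadcasting instead of np.tile, then counts with np.bincount on a flattened index from np.ravel_multi_index. bincount is often faster than np.add.at for this kind of accumulation, though actual timings depend on the NumPy version and array sizes. window_counts_bincount is a hypothetical name, and values are assumed to be integers in [0, n_values).

```python
import numpy as np

def window_counts_bincount(data, r, n_values=10):
    """Scatter formulation counted with np.bincount (hypothetical helper)."""
    h, w = data.shape
    # Source pixel (i, j) contributes to targets (i + di, j + dj) with
    # di, dj in [-r + 1, r], the same offsets the answer above uses.
    off = np.arange(-r + 1, r + 1)
    ii = np.arange(h)[:, None, None, None] + off[None, None, :, None]  # (h, 1, 2r, 1)
    jj = np.arange(w)[None, :, None, None] + off[None, None, None, :]  # (1, w, 1, 2r)
    ii, jj = np.broadcast_arrays(ii, jj)                               # (h, w, 2r, 2r)
    kk = np.broadcast_to(data[:, :, None, None], ii.shape)             # no copy, unlike np.tile
    m = (ii >= 0) & (ii < h) & (jj >= 0) & (jj < w)
    # One flat index per (target row, target col, value) triple, then count.
    flat = np.ravel_multi_index((ii[m], jj[m], kk[m]), (h, w, n_values))
    counts = np.bincount(flat, minlength=h * w * n_values)
    return counts.reshape(h, w, n_values)

data = np.random.default_rng(0).integers(0, 10, (3, 4))
print(window_counts_bincount(data, 2).shape)
```

Memory still grows with (2r)² per pixel, so for large windows the cumsum-based approach is the better fit.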
      

