【问题标题】:numpy vectorization instead of for loopsnumpy 向量化而不是 for 循环
【发布时间】:2016-02-20 21:30:12
【问题描述】:

我用 Python 写了一些代码,运行良好但速度很慢;我认为由于 for 循环。我希望可以使用 numpy 命令加快以下操作。让我定义目标。

假设我有一个二维 numpy 数组 all_CMs,维度为 row x col。例如考虑一个6x11 数组(见下图)。

  1. 我想计算所有行的平均值,即 sumⱼ aᵢⱼ 得到一个数组。当然,这很容易做到。 (我把这个值称为CM_tilde

  2. 现在,对于 每一行,我想计算一些选定值的平均值,即所有低于某个阈值的值,方法是计算它们的总和并将其除以所有列的数量 ( N)。如果该值高于此定义的阈值,则添加 CM_tilde 值(整行的平均值)。这个值叫做CM

  3. 然后,从行中的每个元素中减去 CM

除此之外,我还想要一个 numpy 数组或列表,其中列出了所有 CM 值。

图:

以下代码可以运行,但速度很慢(尤其是在数组变大的情况下)

CM_tilde = np.mean(data, axis=1)
N = data.shape[1]
data_cm = np.zeros(( data.shape[0], data.shape[1], data.shape[2] ))
all_CMs = np.zeros(( data.shape[0], data.shape[2]))
for frame in range(data.shape[2]):
    for row in range(data.shape[0]):
        CM=0
        for col in range(data.shape[1]):
            if data[row, col, frame] < (CM_tilde[row, frame]+threshold):
               CM += data[row, col, frame]
            else:
               CM += CM_tilde[row, frame]
        CM = CM/N
        all_CMs[row, frame] = CM
        # calculate CM corrected value
        for col in range(data.shape[1]):
            data_cm[row, col, frame] = data[row, col, frame] - CM
    print "frame: ", frame
return data_cm, all_CMs

有什么想法吗?

【问题讨论】:

  • 在第 2 步中,您是否实质上将任何高于阈值的值替换为 CM_tilde,然后然后计算整行的平均值,包括替换的值?跨度>
  • 首先使用 np.where 替换您的内部 for 循环。然后,使用广播,您可以删除外部 2 个循环。请参阅where的文档

标签: python numpy vectorization


【解决方案1】:

矢量化你正在做的事情很容易:

import numpy as np

#generate dummy data
nrows=6
ncols=11
nframes=3
threshold=0.3
data=np.random.rand(nrows,ncols,nframes)

CM_tilde = np.mean(data, axis=1)
N = data.shape[1]

all_CMs2 = np.mean(np.where(data < (CM_tilde[:,None,:]+threshold),data,CM_tilde[:,None,:]),axis=1)
data_cm2 = data - all_CMs2[:,None,:]

将此与您的原件进行比较:

In [684]: (data_cm==data_cm2).all()
Out[684]: True

In [685]: (all_CMs==all_CMs2).all()
Out[685]: True

逻辑是我们同时处理大小为[nrows,ncols,nframes] 的数组。主要技巧是利用python的广播,将大小为[nrows,nframes]CM_tilde变成大小为[nrows,1,nframes]CM_tilde[:,None,:]。然后 Python 将为每一列使用相同的值,因为这是修改后的 CM_tilde 的单一维度。

通过使用np.where,我们选择(基于threshold)是要获取data的对应值,还是CM_tilde的广播值。 np.mean 的新用法允许我们计算 all_CMs2

在最后一步中,我们通过直接从data 的相应元素中减去这个新的all_CMs2 来利用广播。

通过查看临时变量的隐式索引,它可能有助于以这种方式对代码进行矢量化。我的意思是您的临时变量CM 存在于[nrows,nframes] 的循环中,并且每次迭代都会重置其值。这意味着CM 实际上是一个数量CM[row,frame](后来显式分配给二维数组all_CMs),从这里很容易看出,您可以通过将适当的CMtmp[row,col,frames] 数量与其相加来构造它列维度。如果有帮助,您可以为此将np.where(...) 部分命名为CMtmp,然后从中计算np.mean(CMtmp,axis=1)。显然,结果相同,但可能更透明。

【讨论】:

  • 非常感谢;与循环相比,这要快得多
  • 10001 对于代表来说是一个不错的价值,如果有人对此投反对票,那就太可惜了。
  • @BhargavRao \o/ 谢谢你,先生!:) 或者,谢谢你不投票:D
【解决方案2】:

这是我对你的函数的矢量化。我从内到外工作,并在我进行过程中注释掉早期版本。所以我矢量化的第一个循环有### 注释标记。

它不像@Andras's 的回答那样干净且有理有据,但希望它具有指导意义,让您了解如何逐步解决此问题。

def foo2(data, threshold):
    CM_tilde = np.mean(data, axis=1)
    N = data.shape[1]
    #data_cm = np.zeros(( data.shape[0], data.shape[1], data.shape[2] ))
    ##all_CMs = np.zeros(( data.shape[0], data.shape[2]))
    bmask = data < (CM_tilde[:,None,:] + threshold)
    CM = np.zeros_like(data)
    CM[:] = CM_tilde[:,None,:]
    CM[bmask] = data[bmask]
    CM = CM.sum(axis=1)
    CM = CM/N
    all_CMs = CM.copy()
    """
    for frame in range(data.shape[2]):
        for row in range(data.shape[0]):
            ###print(frame, row)
            ###mask = data[row, :, frame] < (CM_tilde[row, frame]+threshold)
            ###print(mask)
            ##mask = bmask[row,:,frame]
            ##CM = data[row, mask, frame].sum()
            ##CM += (CM_tilde[row, frame]*(~mask)).sum()

            ##CM = CM/N
            ##all_CMs[row, frame] = CM
            ## calculate CM corrected value
            #for col in range(data.shape[1]):
            #    data_cm[row, col, frame] = data[row, col, frame] - CM[row,frame]
        print "frame: ", frame
    """
    data_cm = data - CM[:,None,:]
    return data_cm, all_CMs

这个小测试用例的输出匹配,最重要的是帮助我得到正确的尺寸。

threshold = .1
data = np.arange(4*3*2,dtype=float).reshape(4,3,2)

【讨论】:

    猜你喜欢
    • 2013-07-21
    • 2018-08-26
    • 2018-12-05
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-04-29
    相关资源
    最近更新 更多