【问题标题】:How can I vectorize this (numpy) operation in python?如何在 python 中对这个(numpy)操作进行矢量化?
【发布时间】:2019-03-18 00:40:20
【问题描述】:

我有两个形状为(batch, dim) 的向量,我试图将它们相减。目前我正在使用一个简单的循环从 1 中减去基于第二个向量(即label)的向量中的特定条目(即error):

per_ts_loss=0
for i, idx in enumerate(np.argmax(label, axis=1)):
    error[i, idx] -=1
    per_ts_loss += error[i, idx]

如何矢量化它?

例如,错误和标签可能如下所示:

error :
array([[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
       [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]])
label:
    array([[0, 0, 0, 1, 0 ],
           [0, 1, 0, 0, 0]])

对于本示例,运行以下代码会产生以下结果:

for i, idx in enumerate(np.argmax(label,axis=1)):
    error[i,idx] -=1
    ls_loss += error[i,idx]

结果:

error: 
 [[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
 [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]]
label: 
 [[ 0.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  0.]]

error(indexes 3 and 1 are changed): 
[[ 0.5488135   0.71518937  0.60276338 -0.45511682  0.4236548 ]
 [ 0.64589411 -0.56241279  0.891773    0.96366276  0.38344152]]
per_ts_loss: 
 -1.01752960574

这是代码本身:https://ideone.com/e1k8ra

我不知道如何使用np.argmax 的结果,因为结果是一个新的索引向量,不能简单地像这样使用:

 error[:, np.argmax(label, axis=1)] -=1

所以我被困在这里了!

【问题讨论】:

  • 您能分享一个数组样本和预期输出吗?
  • 好的,我会在几秒钟内编辑问题。
  • 另外,在最后一行,e 是什么,错误?
  • @yatu:那是一个错字,我改正了。还提供了示例
  • 标签总是0还是1?

标签: python numpy


【解决方案1】:

替换:

error[:, np.argmax(label, axis=1)] -=1

与:

error[np.arange(error.shape[0]), np.argmax(label, axis=1)] -=1

当然

loss = error[np.arange(error.shape[0]), np.argmax(label, axis=1)].sum()

在您的示例中,您正在更改和求和 error[0,3]error[1,1],或者简而言之 error[[0,1],[3,1]]

【讨论】:

    【解决方案2】:

    也许是这样的:

    import numpy as np
    
    
    error = np.array([[0.32783139, 0.29204386, 0.0572163 , 0.96162543, 0.8343454 ],
           [0.67308787, 0.27715222, 0.11738748, 0.091061  , 0.51806117]])
    
    label= np.array([[0, 0, 0, 1, 0 ],
               [0, 1, 0, 0, 0]])
    
    
    
    def f(error, label):
        per_ts_loss=0
        t=np.zeros(error.shape)
        argma=np.argmax(label, axis=1)
        t[[i for i in range(error.shape[0])],argma]=-1
        print(t)
        error+=t
        per_ts_loss += error[[i for i in range(error.shape[0])],argma]
    
    
    f(error, label)
    

    输出:

    [[ 0.  0.  0. -1.  0.]
     [ 0. -1.  0.  0.  0.]]
    

    【讨论】:

    • 感谢您的努力,但没有。 O 不要减去标签,标签用于从错误中减去 1(这是为了显示我们距离正确答案(即 1 )还有多少,这也不是矢量化实现。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-10-09
    • 2016-05-26
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多