【发布时间】:2019-03-18 00:40:20
【问题描述】:
我有两个形状为(batch, dim) 的向量,我试图将它们相减。目前我正在使用一个简单的循环从 1 中减去基于第二个向量(即label)的向量中的特定条目(即error):
per_ts_loss=0
for i, idx in enumerate(np.argmax(label, axis=1)):
error[i, idx] -=1
per_ts_loss += error[i, idx]
如何矢量化它?
例如,错误和标签可能如下所示:
error :
array([[ 0.5488135 0.71518937 0.60276338 0.54488318 0.4236548 ]
[ 0.64589411 0.43758721 0.891773 0.96366276 0.38344152]])
label:
array([[0, 0, 0, 1, 0 ],
[0, 1, 0, 0, 0]])
对于本示例,运行以下代码会产生以下结果:
for i, idx in enumerate(np.argmax(label,axis=1)):
error[i,idx] -=1
ls_loss += error[i,idx]
结果:
error:
[[ 0.5488135 0.71518937 0.60276338 0.54488318 0.4236548 ]
[ 0.64589411 0.43758721 0.891773 0.96366276 0.38344152]]
label:
[[ 0. 0. 0. 1. 0.]
[ 0. 1. 0. 0. 0.]]
error(indexes 3 and 1 are changed):
[[ 0.5488135 0.71518937 0.60276338 -0.45511682 0.4236548 ]
[ 0.64589411 -0.56241279 0.891773 0.96366276 0.38344152]]
per_ts_loss:
-1.01752960574
这是代码本身:https://ideone.com/e1k8ra
我不知道如何使用np.argmax 的结果,因为结果是一个新的索引向量,不能简单地像这样使用:
error[:, np.argmax(label, axis=1)] -=1
所以我被困在这里了!
【问题讨论】:
-
您能分享一个数组样本和预期输出吗?
-
好的,我会在几秒钟内编辑问题。
-
另外,在最后一行,
e是什么,错误? -
@yatu:那是一个错字,我改正了。还提供了示例
-
标签总是0还是1?