【发布时间】:2015-07-09 22:13:25
【问题描述】:
我有一个张量 probs 和 probs.shape = (max_time, num_batches, num_labels)。
我有一个张量targets 和targets.shape = (max_seq_len, num_batches),其中的值是标签索引,即probs 中的第三维。
现在我想得到一个张量probs_y 和probs.shape = (max_time, num_batches, max_seq_len),其中第三维是targets 中的索引。基本上
probs_y[:,i,:] = probs[:,i,targets[:,i]]
所有0 <= i < num_batches。
我怎样才能做到这一点?
here 发布了类似的解决方案。
如果我理解正确的话,那里的解决方案是:
probs_y = probs[:,T.arange(targets.shape[1])[None,:],targets]
但这似乎不起作用。我得到:
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices.
另外,时间T.arange 的创建不是有点昂贵吗?特别是当我尝试通过真正使其成为一个完整的密集整数数组来解决问题时。应该有更好的办法。
也许theano.map?但据我了解,这不会并行化代码,所以这也不是解决方案。
【问题讨论】:
-
刚刚意识到我所做的与您的行不同的是我在
T.arange和targets中都调换了轴。这很奇怪。在这种情况下,您的也应该有效。 -
好的,你这样做的方式也有效,我更新了我的答案。所以问题出在其他地方。 Theano 版本或与此特定操作无关的内容 - 尽管给出错误消息,但后者似乎不太可能。