【发布时间】:2019-11-26 00:01:58
【问题描述】:
我目前正在将一些代码从 tensorflow 转换为 pytorch,我遇到了 tf.gather func 的问题,在 pytorch 中没有直接转换它的函数。
我要做的基本上是索引,我有两个张量,张量形状为[minibatch, 60, 2] 和索引张量[minibatch, 8],比如说第一个张量是张量A,第二个是B .
在Tensorflow中直接用tf.gather(A, B, batch_dims=1)转换
如何在 pytorch 中实现这一点?
我已经尝试过A[B] 索引。这个好像不行
和A[0]B[0] 有效,但形状的输出是[8, 2]
我需要[minibatch, 8, 2]的形状
如果我像 [stack, 8, 2] 这样堆叠张量,它可能会起作用,但我不知道该怎么做
tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great
[minibatch, 8, 2]的输出形状
【问题讨论】: