【问题标题】:Pytorch tensor indexingPytorch 张量索引
【发布时间】:2019-11-26 00:01:58
【问题描述】:

我目前正在将一些代码从 tensorflow 转换为 pytorch,我遇到了 tf.gather func 的问题,在 pytorch 中没有直接转换它的函数。

我要做的基本上是索引,我有两个张量,张量形状为[minibatch, 60, 2] 和索引张量[minibatch, 8],比如说第一个张量是张量A,第二个是B .

在Tensorflow中直接用tf.gather(A, B, batch_dims=1)转换

如何在 pytorch 中实现这一点?

我已经尝试过A[B] 索引。这个好像不行

A[0]B[0] 有效,但形状的输出是[8, 2]

我需要[minibatch, 8, 2]的形状

如果我像 [stack, 8, 2] 这样堆叠张量,它可能会起作用,但我不知道该怎么做

tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great

[minibatch, 8, 2]的输出形状

【问题讨论】:

    标签: python indexing pytorch


    【解决方案1】:

    我想你在找torch.gather

    out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))
    

    【讨论】:

    • @AydenLee 很高兴我能提供帮助。如果此答案对您有用,请考虑通过单击旁边的“v”图标“接受”它
    猜你喜欢
    • 2020-03-28
    • 2021-11-04
    • 2020-07-20
    • 2019-09-15
    • 2020-09-26
    • 2021-06-27
    • 2020-12-06
    • 2019-02-05
    • 1970-01-01
    相关资源
    最近更新 更多