【问题标题】:PyTorch how to do gathers over multiple dimensionsPyTorch 如何在多个维度上进行收集
【发布时间】:2021-04-29 19:30:54
【问题描述】:

我正在尝试找到一种不使用 for 循环的方法。

假设我有一个多维张量t0

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))

这有形状:torch.Size([4, 10, 16])

我还有另一个张量 labels,它是 seq 维度中的一组 5 个随机索引:

labels = torch.randint(0, seq, size=[bs, sample])

所以它的形状为torch.Size([4, 5])。这用于索引t0seq 维度。

我想要做的是循环使用 labels 张量在批处理维度上进行收集。我的蛮力解决方案是这样的:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

导致张量 t1 的形状为:torch.Size([4, 5, 16])

在 pytorch 中是否有更惯用的方法?

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    您可以在此处使用fancy indexing 选择所需的张量部分。

    本质上,如果您预先生成传达您的访问模式的索引数组,您可以直接使用它们来提取张量的一些切片。每个维度的索引数组的形状应与您要提取的输出张量或切片的形状相同。

    i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1,      1]
    j = labels.reshape(bs, sample, 1)      # shape = [bs, sample, 1]
    k = torch.arange(v)                    # shape = [v, ]
    
    # Get result as
    t1 = t0[i, j, k]
    

    注意以上 3 个张量的形状。 Broadcasting 在张量的前面附加了额外的维度,因此基本上将 k 重塑为 [1, 1, v] 形状,这使得它们中的所有 3 个都兼容元素操作。

    在广播 (i, j, k) 之后将产生 3 个 [bs, sample, v] 形状的数组,这些数组将(按元​​素)索引您的原始张量以产生形状为 [bs, sample, v] 的输出张量 t1

    【讨论】:

      【解决方案2】:

      你可以这样做:

      t1 = t0[[[b] for b in range(bs)], labels]
      

      t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])
      

      【讨论】:

        猜你喜欢
        • 2020-07-24
        • 1970-01-01
        • 2019-05-27
        • 2021-04-23
        • 1970-01-01
        • 2021-07-26
        • 1970-01-01
        • 2021-02-10
        • 1970-01-01
        相关资源
        最近更新 更多