【发布时间】:2021-04-29 19:30:54
【问题描述】:
我正在尝试找到一种不使用 for 循环的方法。
假设我有一个多维张量t0:
bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))
这有形状:torch.Size([4, 10, 16])
我还有另一个张量 labels,它是 seq 维度中的一组 5 个随机索引:
labels = torch.randint(0, seq, size=[bs, sample])
所以它的形状为torch.Size([4, 5])。这用于索引t0 的seq 维度。
我想要做的是循环使用 labels 张量在批处理维度上进行收集。我的蛮力解决方案是这样的:
t1 = torch.empty((bs, sample, v))
for b in range(bs):
for idx0, idx1 in enumerate(labels[b]):
t1[b, idx0, :] = t0[b, idx1, :]
导致张量 t1 的形状为:torch.Size([4, 5, 16])
在 pytorch 中是否有更惯用的方法?
【问题讨论】: