【问题标题】:Pytorch cut 2d array by lengths on first dimensionPytorch 在第一维上按长度切割二维数组
【发布时间】:2020-06-17 10:07:05
【问题描述】:

我有一个二维数组,假设大小为torch.tensor(batch_size, 1000)。来自第二维的 1000 数组实际上是可变长度的。我有第二个大小为[batch_size] 的数组,其中包含每行的长度...

这里是一个示例代码sn-p:

# preds is the 2d array of size [batch_size, 1000]
# lengths is a 1d array containing the lengths of each row of preds
res_pred = []
for i in range(len(preds)):
    length = lengths[i].item()
    res_pred += [preds[i][:length]]

result = torch.cat(res_pred).flatten()

我对我的目标做同样的事情,然后我可以对两者应用损失函数。

我想知道是否有一个向量化操作可以提取所有可变长度的batch_size 向量和torch.cat 一起。现在我在第一个维度上循环,但这感觉很慢。

谢谢,

【问题讨论】:

  • 你能添加一个你的数组的例子吗?此外,使用循环添加慢速代码将有助于理解预期的输出。
  • 你如何期望torch.cat 不同长度的向量?你要创建一个大小为(1, sum(vector_lengths))单个向量吗?
  • @AndreasK.,我添加了一个代码示例。我应该先这样做。谢谢
  • 我不确定有没有一种有效的方法来做你想做的事。作为替代方案,您是否考虑过创建一个与res_pred 形状相同的掩码并仅计算掩码中具有 1 的值的损失? This answer 可能会帮助您创建面具。

标签: python arrays numpy multidimensional-array pytorch


【解决方案1】:

您可以使用lengths[i] 给出的第 i 行中 True 的数量创建一个 2D 掩码张量。这是一个例子:

batch_size = 6
n = 5

preds = torch.arange(batch_size * n).reshape(batch_size, n)
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19],
#         [20, 21, 22, 23, 24],
#         [25, 26, 27, 28, 29]])

#lengths = np.random.randint(0, n+1, batch_size)
lengths = torch.randint(0, n+1, (batch_size, ))
# tensor([2, 0, 5, 3, 3, 2])

让我们创建蒙版并得到我们的结果(可能有更好的方法来创建这样的蒙版,但这就是我想出的):

#mask = np.tile(range(n), (batch_size,1)) < lengths[:,None]
mask = torch.arange(n).repeat((batch_size,1)) < lengths[:, None]
# tensor([[ True,  True, False, False, False],
#        [False, False, False, False, False],
#        [ True,  True,  True,  True,  True],
#        [ True,  True,  True, False, False],
#        [ True,  True,  True, False, False],
#        [ True,  True, False, False, False]])

#result = preds[mask]
result = torch.masked_select(preds, mask)
# tensor([0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 25, 26])

这会产生与您的代码相同的结果:

res_pred = []
for i in range(len(preds)):
    length = lengths[i].item()
    res_pred += [preds[i][:length]]

result = torch.cat(res_pred).flatten()

【讨论】:

  • 我编辑了答案,使其使用 pytorch 而不是 numpy(numpy 的解决方案已注释)。
猜你喜欢
  • 1970-01-01
  • 2012-07-30
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多