【发布时间】:2021-11-28 20:15:03
【问题描述】:
我怎样才能执行 批处理 masked_select?
给定:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
期望的输出是:
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这是一种可能的方法:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
ones = torch.tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
masks = torch.tensor([[ True, False, False, False, True],
[ True, False, True, True, False]])
for i in range(x.size(0)):
mask = masks[i]
s = torch.masked_select(x[i], mask)
ones[i][:s.size(0)] = s
有其他解决方案吗?
【问题讨论】:
-
@ivan 我看到你也遇到过类似的问题。你有解决方案吗?谢谢!
标签: deep-learning pytorch pytorch-dataloader