【发布时间】:2021-03-18 21:16:50
【问题描述】:
我需要以一定的概率将张量new的元素插入到张量old中,为了简单起见,假设为0.8。 基本上这就是 masked_fill 会做的事情,但它只适用于一维张量。 其实我在做
prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
mask = prob < 0.8
dim1, dim2, dim3, dim4 = new.shape
for a in range(dim1):
for b in range(dim2):
for c in range(dim3):
for d in range(dim4):
old[a][b][c][d] = old[a][b][c][d] if mask[a][b][c][d] else new[a][b][c][d]
这太糟糕了。我想要类似的东西
prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
mask = prob < 0.8
old = trgs.multidimensional_masked_fill(mask, new)
【问题讨论】: