【发布时间】:2020-10-03 02:46:29
【问题描述】:
我想知道下面的代码是否有更有效的替代方法,而不使用第 4 行中的“for”循环?
import torch
n, d = 37700, 7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
mask = torch.zeros(n, d, dtype=torch.bool)
mask.scatter_(dim=1, index=sample, value=True)
基本上,我要做的是通过d 掩码张量创建一个n,这样每一行中的k 随机元素都是True。
【问题讨论】:
标签: pytorch