正如上面 cmets 中的 @Marcel 所解释的,您可以首先将第一个 m 值设置为值 k,然后按置换索引进行索引以获得随机张量:
>>> n = 10; m = 3; k = 1
>>> x = torch.zeros(n, n)
>>> x[:, :m] = k
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.]])
使用torch.randperm 获取逐行列排列:
>>> perm = torch.stack([torch.randperm(10) for _ in range(len(x))])
tensor([[8, 0, 3, 2, 1, 6, 9, 4, 5, 7],
[5, 7, 1, 4, 8, 0, 6, 9, 2, 3],
[2, 1, 9, 7, 0, 8, 6, 3, 5, 4],
[1, 3, 5, 8, 7, 6, 9, 4, 2, 0],
[7, 6, 0, 5, 2, 9, 1, 8, 4, 3],
[5, 0, 6, 8, 1, 9, 2, 4, 3, 7],
[4, 0, 6, 5, 8, 1, 3, 7, 2, 9],
[5, 3, 4, 9, 0, 1, 7, 6, 8, 2],
[5, 7, 9, 3, 2, 6, 8, 0, 4, 1],
[2, 7, 4, 6, 3, 0, 9, 8, 5, 1]])
然后使用torch.gather 索引张量x 和perm:
>>> x.gather(dim=0, index=perm)
tensor([[0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 0., 0., 1.]])
或者,您可以直接使用 torch.scatter 和 value 关键字参数:
>>> torch.zeros(n, n).scatter(dim=0, index=perm, value=1)
tensor([[0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 0., 0., 1.]])
如果m 本身就是张量,您可以使用torch.arange 和torch.where 的组合找到解决方法:
首先对位置进行编码:
>>> d = torch.arange(n)[None].repeat(n,1)
>>> x = torch.where(d+m>n, 0, 1)
tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
像以前一样构造排列:
>>> perm = torch.stack([torch.randperm(10) for _ in range(n)])
tensor([[2, 5, 7, 0, 4, 1, 3, 6, 8, 9],
[7, 4, 9, 5, 6, 0, 3, 1, 2, 8],
[5, 1, 4, 9, 0, 3, 2, 6, 7, 8],
[9, 6, 0, 2, 3, 1, 7, 5, 4, 8],
[3, 5, 4, 6, 0, 7, 9, 8, 2, 1],
[5, 7, 8, 6, 9, 2, 0, 4, 3, 1],
[8, 3, 9, 0, 6, 2, 5, 7, 4, 1],
[2, 9, 4, 3, 7, 8, 1, 0, 6, 5],
[5, 4, 8, 3, 2, 9, 7, 1, 6, 0],
[8, 7, 3, 6, 5, 4, 2, 0, 9, 1]])
然后分散在x:
>>> x.scatter(dim=0, index=perm, value=1)
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])