【发布时间】:2019-05-17 19:02:25
【问题描述】:
import numpy as np
import torch
a = torch.zeros(5)
b = torch.tensor(tuple((0,1,0,1,0)),dtype=torch.uint8)
c= torch.tensor([7.,9.])
print(a[b].size())
a[b]=c
print(a)
torch.Size([2])
张量([0., 7., 0., 9., 0.])
我很难理解这是如何工作的。我最初认为上面的代码使用了花式索引,但我意识到来自 c 张量的值被复制对应于标记为 1 的索引。另外,如果我没有指定 b as uint8 那么上面的代码不起作用。有人可以解释一下上述代码的机制吗?
【问题讨论】:
标签: pytorch