【问题标题】:Copying data from one tensor to another using bit masking使用位掩码将数据从一个张量复制到另一个张量
【发布时间】:2019-05-17 19:02:25
【问题描述】:
import numpy as np
import torch
a = torch.zeros(5)
b = torch.tensor(tuple((0,1,0,1,0)),dtype=torch.uint8)
c= torch.tensor([7.,9.])
print(a[b].size())
a[b]=c
print(a)

torch.Size([2])
张量([0., 7., 0., 9., 0.])

我很难理解这是如何工作的。我最初认为上面的代码使用了花式索引,但我意识到来自 c 张量的值被复制对应于标记为 1 的索引。另外,如果我没有指定 b as uint8 那么上面的代码不起作用。有人可以解释一下上述代码的机制吗?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    使用数组进行索引的工作方式与我知道的 numpy 和大多数其他矢量化数学包中的相同。有两种情况:

      1234563 /em> a (a[i]) 的值,b (b[i]) 中的对应值非零。这些值与原始a 具有别名,因此如果您修改它们,它们对应的位置也会发生变化。
    1. 可用于索引的替代类型是int64 数组,在这种情况下a[b] 创建一个形状为(*b.shape, *a.shape[1:]) 的数组。它的结构就好像b (b[i]) 的每个元素都被a[i] 替换了一样。换句话说,您通过指定应从a 的哪些索引获取数据来创建一个新数组。同样,这些值与原始 a 具有别名,因此如果您修改 a[b],则每个 ia[b[i]] 的值都会改变。 this 问题中显示了一个示例用例。

    这两种模式在integer array indexingboolean array indexing 中针对numpy 进行了解释,对于后者,您必须记住pytorch 使用uint8 代替bool

    另外,如果您的目标是将数据从一个张量复制到另一个张量,您必须记住,像 a[ixs] = b[ixs] 这样的操作是就地操作(a 已就地修改),我不会玩与autograd很好。如果要进行不适当的掩蔽,请使用torch.wherethis 答案中显示了一个示例用例。

    【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2016-06-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-01-20
    • 2019-09-04
    • 1970-01-01
    相关资源
    最近更新 更多