【问题标题】:torch logical indexing of tensor张量的火炬逻辑索引
【发布时间】:2016-07-20 11:38:43
【问题描述】:

我正在寻找一种优雅的方式来选择满足某些约束的火炬张量子集。 例如,假设我有:

A = torch.rand(10,2)-1

S 是一个 10x1 的张量,

sel = torch.ge(S,5) -- this is a ByteTensor

我希望能够做逻辑索引,如下:

A1 = A[sel]

但这不起作用。 所以有一个index 函数接受LongTensor,但我找不到将S 转换为LongTensor 的简单方法,除了以下内容:

sel = torch.nonzero(sel)

它返回一个 K x 2 张量(K 是 S >= 5 的值的数量)。那么我必须将它转换为一维数组,这最终允许我索引 A:

A:index(1,torch.squeeze(sel:select(2,1)))

这很麻烦;在例如Matlab 我所要做的就是

A(S>=5,:)

谁能提出更好的方法?

【问题讨论】:

    标签: indexing torch


    【解决方案1】:

    一种可能的选择是:

    sel = S:ge(5):expandAs(A)   -- now you can use this mask with the [] operator
    A1 = A[sel]:unfold(1, 2, 2) -- unfold to get back a 2D tensor
    

    例子:

    > A = torch.rand(3,2)-1
    -0.0047 -0.7976
    -0.2653 -0.4582
    -0.9713 -0.9660
    [torch.DoubleTensor of size 3x2]
    
    > S = torch.Tensor{{6}, {1}, {5}}
     6
     1
     5
    [torch.DoubleTensor of size 3x1]
    
    > sel = S:ge(5):expandAs(A)
    1  1
    0  0
    1  1
    [torch.ByteTensor of size 3x2]
    
    > A[sel]
    -0.0047
    -0.7976
    -0.9713
    -0.9660
    [torch.DoubleTensor of size 4]
    
    > A[sel]:unfold(1, 2, 2)
    -0.0047 -0.7976
    -0.9713 -0.9660
    [torch.DoubleTensor of size 2x2]
    

    【讨论】:

    • 很好,展开是我所缺少的。虽然我认为它可以通过重塑更直观地完成。谢谢!
    【解决方案2】:

    有两种更简单的选择:

    1. 使用maskedSelect:

      result=A:maskedSelect(your_byte_tensor)

    2. 使用简单的元素乘法,例如

      result=torch.cmul(A,S:gt(0))

    如果您需要保持原始矩阵的形状( A),第二个非常有用,例如在反向传播中选择层中的神经元。但是,由于只要 ByteTensor 规定的条件不适用,它就会在结果矩阵中放置零,因此您不能使用它来计算乘积(或中位数等)。第一个只返回满足条件的元素,所以这是我用来计算产品或中位数或任何其他我不想要零的东西。

    【讨论】:

      猜你喜欢
      • 2022-07-20
      • 2022-10-23
      • 2016-06-27
      • 2020-08-02
      • 2021-07-08
      • 1970-01-01
      • 2020-06-11
      • 2020-10-26
      • 2021-06-18
      相关资源
      最近更新 更多