【问题标题】:Why does torch.scatter requires a smaller shape for indices than values?为什么 torch.scatter 需要比值更小的索引形状?
【发布时间】:2021-07-06 17:06:04
【问题描述】:

here 已经提出了类似的问题,但我认为该解决方案不适合我的情况。

我只是想知道为什么不能进行torch.scatter 操作,其中我的索引张量大于我的值张量。就我而言,我有重复的索引,例如以下值张量a和索引张量idx

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])

idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

a.scatter(-1, idx, 1) 返回:

RuntimeError: 期望索引 [2, 5] 小于 self [2, 4] 除了维度 1 并且小于 src [2, 4]

还有其他方法可以实现吗?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    不是解决方案,而是解决方法:

    a = torch.tensor([[0, 1, 0, 0],
                      [0, 0, 1, 0]])
    
    idx = torch.tensor([[1, 1, 2, 3, 3],
                        [0, 0, 1, 2, 2]])
    
    rows = torch.arange(0, a.size(0))[:,None]
    n_col = idx.size(1)
    a[rows.repeat(1, n_col), idx] = 1
    

    rows.repeat(1, n_col) 将行索引赋予idx 中对应的列索引。

    【讨论】:

      猜你喜欢
      • 2014-12-08
      • 1970-01-01
      • 2021-03-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多