【问题标题】:2d array as index in Pytorch二维数组作为 Pytorch 中的索引
【发布时间】:2018-12-20 15:15:34
【问题描述】:

我想使用一组规则“增长”一个矩阵。

规则示例:

0->[[1,1,1],[0,0,0],[2,2,2]],
1->[[2,2,2],[2,2,2],[2,2,2]],
2->[[0,0,0],[0,0,0],[0,0,0]]

增长矩阵的例子:

[[0]]->[[1,1,1],[0,0,0],[2,2,2]]->
[[2,2,2,2,2,2,2,2,2],[2,2,2,2,2,2,2,2,2],[2,2,2,2,2,2,2,2,2],
[1,1,1,1,1,1,1,1,1],[0,0,0,0,0,0,0,0,0],[2,2,2,2,2,2,2,2,2],
[0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0]]

这是我一直试图在 Pytorch 中工作的代码

rules = np.random.randint(256,size=(10,256,3,3,3))
rules_tensor = torch.randint(256,size=(10,
            256, 3, 3, 3),
            dtype=torch.uint8, device = torch.device('cuda'))

rules = rules[0]
rules_tensor = rules_tensor[0]

seed = np.array([[128]])
seed_tensor = seed_tensor = torch.cuda.ByteTensor([[128]])

decode = np.empty((3**3, 3**3, 3))
decode_tensor = torch.empty((3**3,
                3**3, 3), dtype=torch.uint8,
                device = torch.device('cuda'))

for i in range(3):
    grow = seed
    grow_tensor = seed_tensor
    for j in range(1,4):
        grow = rules[grow,:,:,i].reshape(3**j,-1)
        grow_tensor = rules_tensor[grow_tensor,:,:,i].reshape(3**j,-1)

    decode[..., i] = grow
    decode_tensor[..., i] = grow_tensor

在这一行中,我似乎无法像在 Numpy 中那样选择索引:

grow = rules[grow,:,:,i].reshape(3**j,-1)

有没有办法在 Pytorch 中执行以下操作?

【问题讨论】:

    标签: python numpy matrix indexing pytorch


    【解决方案1】:

    您可以考虑使用torch.index_select(),在重塑结果之前展平您的索引张量:

    代码:

    import torch
    import numpy as np
    
    rules_np = np.array([
        [[1,1,1],[0,0,0],[2,2,2]],  # for value 0
        [[2,2,2],[2,2,2],[2,2,2]],  # for value 1
        [[0,0,0],[0,0,0],[0,0,0]]]) # for value 2, etc.
    rules = torch.from_numpy(rules_np).long()
    rule_shape = rules[0].shape
    
    seed = torch.zeros(1).long()
    num_growth = 2
    print("Seed:")
    print(seed)
    
    grow = seed
    for i in range(num_growth):
        grow = (torch.index_select(rules, 0, grow.view(-1))
                .view(grow.shape + rule_shape)
                .squeeze())
        print("Growth #{}:".format(i))
        print(grow)
    

    日志:

    Seed:
    tensor([ 0])
    Growth #0:
    tensor([[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]])
    Growth #1:
    tensor([[[[ 2,  2,  2], [ 2,  2,  2], [ 2,  2,  2]],
             [[ 2,  2,  2], [ 2,  2,  2], [ 2,  2,  2]],
             [[ 2,  2,  2], [ 2,  2,  2], [ 2,  2,  2]]],
    
            [[[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]],
             [[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]],
             [[ 1,  1,  1], [ 0,  0,  0], [ 2,  2,  2]]],
    
            [[[ 0,  0,  0], [ 0,  0,  0], [ 0,  0,  0]],
             [[ 0,  0,  0], [ 0,  0,  0], [ 0,  0,  0]],
             [[ 0,  0,  0], [ 0,  0,  0], [ 0,  0,  0]]]])
    

    【讨论】:

      猜你喜欢
      • 2019-09-15
      • 1970-01-01
      • 2016-06-22
      • 2022-01-05
      • 2018-06-08
      • 2012-04-01
      • 2023-03-07
      • 2020-07-30
      相关资源
      最近更新 更多