【问题标题】:Concat tensors in PyTorchPyTorch 中的连接张量
【发布时间】:2019-07-10 16:20:06
【问题描述】:

我有一个名为 data 的张量,形状为 [128, 4, 150, 150],其中 128 是批量大小,4 是通道数,最后两个维度是高度和宽度。我有另一个名为fake 的张量,形状为[128, 1, 150, 150]

我想从data的第二维中删除最后一个list/array;数据的形状现在是[128, 3, 150, 150];并将其与fake 连接起来,将连接的输出维度设为[128, 4, 150, 150]

换句话说,我想将data 的前三个维度与fake 连接起来,得到一个4 维张量。

我正在使用 PyTorch,遇到了 torch.cat()torch.stack() 函数

这是我编写的示例代码:

fake_combined = []
        for j in range(batch_size):
            fake_combined.append(torch.stack((data[j][0].to(device), data[j][1].to(device), data[j][2].to(device), fake[j][0].to(device))))
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
fake_combined = fake_combined.to(device)

但我在行中遇到错误:

fake_combined = torch.tensor(fake_combined, dtype=torch.float32)

错误是:

ValueError: only one element tensors can be converted to Python scalars

另外,如果我打印fake_combined 的形状,我得到的输出是[128,] 而不是[128, 4, 150, 150]

当我打印fake_combined[0] 的形状时,我得到的输出为[4, 150, 150],这与预期的一样。

所以我的问题是,为什么我不能使用torch.tensor() 将列表转换为张量。我错过了什么吗?有没有更好的方法来做我打算做的事情?

任何帮助将不胜感激!谢谢!

【问题讨论】:

    标签: python machine-learning pytorch tensor


    【解决方案1】:

    @rollthedice32 的回答非常好。出于教育目的,这里使用torch.cat

    a = torch.rand(128, 4, 150, 150)
    b = torch.rand(128, 1, 150, 150)
    
    # Cut out last dimension
    a = a[:, :3, :, :]
    # Concatenate in 2nd dimension
    result = torch.cat([a, b], dim=1)
    print(result.shape)
    # => torch.Size([128, 4, 150, 150])
    

    【讨论】:

      【解决方案2】:

      您也可以只分配给该特定维度。

      orig = torch.randint(low=0, high=10, size=(2,3,2,2))
      fake = torch.randint(low=111, high=119, size=(2,1,2,2))
      orig[:,[2],:,:] = fake
      

      原来的之前

      tensor([[[[0, 1],
            [8, 0]],
      
           [[4, 9],
            [6, 1]],
      
           [[8, 2],
            [7, 6]]],
      
      
          [[[1, 1],
            [8, 5]],
      
           [[5, 0],
            [8, 6]],
      
           [[5, 5],
            [2, 8]]]])
      

      假的

      tensor([[[[117, 115],
            [114, 111]]],
      
      
          [[[115, 115],
            [118, 115]]]])
      

      原版之后

      tensor([[[[  0,   1],
            [  8,   0]],
      
           [[  4,   9],
            [  6,   1]],
      
           [[117, 115],
            [114, 111]]],
      
      
          [[[  1,   1],
            [  8,   5]],
      
           [[  5,   0],
            [  8,   6]],
      
           [[115, 115],
            [118, 115]]]])
      

      希望这会有所帮助! :)

      【讨论】:

        猜你喜欢
        • 2021-10-30
        • 2021-06-18
        • 2018-10-29
        • 2021-12-01
        • 2021-10-11
        • 2020-09-09
        • 2021-02-07
        • 1970-01-01
        • 2019-07-12
        相关资源
        最近更新 更多