【Question Title】: Extracting 3D patches from 3D images in both overlapping and nonoverlapping process and recovering the image back
【Posted】: 2019-07-21 22:45:05
【Question Description】:

I am working with 3D images of shape 172x220x156. To feed an image to the network and get an output, I need to extract patches of size 32x32x32 from the image, and then add them back together to obtain the image again. Since my image dimensions are not multiples of the patch size, I have to use overlapping patches. I would like to know how to do this.

I am working in PyTorch, and there are options such as unfold and fold, but I am not sure how they work.
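Editor's note: the unfold mentioned above can be previewed with a tiny example (my own illustration, not from the thread). Tensor.unfold slides a fixed-size window along one dimension with a given step:

```python
import torch

# Slide a window of size 3 with step 2 over a 1D tensor of length 7.
# Number of windows: (7 - 3) // 2 + 1 = 3.
t = torch.arange(7)
w = t.unfold(0, 3, 2)
print(w)
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6]])
```

Note that with step 2 and size 3 the windows overlap by one element; fold performs the inverse, summing overlapping contributions back into place.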

【Question Discussion】:

    Tags: python 3d pytorch patch mri


    【Solution 1】:

    Is all of your data exactly 172x220x156? If so, it seems you could just use for loops and index into the tensor to get 32x32x32 blocks, right? (Possibly hard-coding a few things.)

    However, I cannot fully answer this, because it is not clear how you want to combine the results. To be clear, is this your goal?

    1) Get a 32x32x32 patch from the image
    2) Do some arbitrary processing on it
    3) Save that patch into some result at the correct index
    4) Repeat

    If so, how do you intend to combine the overlapping patches? Sum them? Take the mean?

    But, for the indexing:

    out_tensor = torch.zeros_like(input_tensor)
    for i_idx in [0, 32, 64, 96, 128, 140]:
        for j_idx in [0, 32, 64, 96, 128, 160, 188]:
            for k_idx in [0, 32, 64, 96, 124]:
                # slice out a full 32x32x32 block, not a single element
                patch = input_tensor[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32]
                output = your_model(patch)
                out_tensor[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32] = output
    

    This is not optimized at all, but I would guess that most of the computation will be the actual neural network, and there is no way around that, so optimizing this part is probably pointless.
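Editor's note: the hard-coded start-index lists above can be generated for any dimension length. A small hypothetical helper (the name patch_starts is my own, not from the answer) that shifts the final patch back so it overlaps rather than runs past the edge:

```python
# Compute patch start offsets along one dimension so fixed-size patches
# cover it completely; the last patch is shifted back to overlap when the
# dimension is not a multiple of the stride.
def patch_starts(dim_size, patch_size, stride):
    starts = list(range(0, dim_size - patch_size + 1, stride))
    if starts[-1] + patch_size < dim_size:
        starts.append(dim_size - patch_size)  # final overlapping patch
    return starts

print(patch_starts(172, 32, 32))  # [0, 32, 64, 96, 128, 140]
print(patch_starts(220, 32, 32))  # [0, 32, 64, 96, 128, 160, 188]
print(patch_starts(156, 32, 32))  # [0, 32, 64, 96, 124]
```

These reproduce exactly the three index lists hard-coded in the loop above.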

    【Discussion】:

    • Yes, all the data is 172x220x156. I have to create patches from the image because the whole image cannot be fed into the network; after getting the outputs, I add them back together. I could pad the input image, but I would rather take overlapping patches and average the intensities in the overlap regions.
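Editor's note: that intensity-averaging idea can be sketched minimally (1D for brevity, with an identity stand-in for the model; both are my assumptions, not from the thread): accumulate patch outputs and a hit counter, then divide.

```python
import torch

# Identity "model" on overlapping 1D patches: sum outputs and per-element
# hit counts, then divide so overlapped regions become averages.
signal = torch.arange(10, dtype=torch.float)
patch = 4
starts = [0, 3, 6]  # last patch (6:10) overlaps to cover the remainder
out = torch.zeros_like(signal)
count = torch.zeros_like(signal)
for s in starts:
    out[s:s+patch] += signal[s:s+patch]  # stand-in for model(patch)
    count[s:s+patch] += 1
avg = out / count
print(torch.equal(avg, signal))  # True: averaging undoes the double-counting
```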
    【Solution 2】:

    You can use unfold (pytorch docs):

    import torch

    batch_size, n_channels, n_rows, n_cols = 1, 172, 220, 156
    x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)
    
    kernel_c, kernel_h, kernel_w = 32, 32, 32
    step = 32
    
    # Tensor.unfold(dimension, size, step)
    windows_unpacked = x.unfold(1, kernel_c, step).unfold(2, kernel_h, step).unfold(3, kernel_w, step)
    print(windows_unpacked.shape)
    # result: torch.Size([1, 5, 6, 4, 32, 32, 32])
    
    windows = windows_unpacked.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
    print(windows.shape)
    # result: torch.Size([120, 32, 32, 32])
    

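Editor's note: Tensor.unfold silently drops any trailing remainder (172 = 5*32 + 12, and so on), so the 120 windows above do not cover the whole volume. One workaround, sketched under the assumption that zero-padding the borders is acceptable, is to pad each spatial dimension up to the next multiple of the kernel before unfolding:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 172, 220, 156)
k = 32
pads = []
for size in reversed(x.shape[1:]):  # F.pad pads the last dimensions first
    extra = (-size) % k             # amount needed to reach a multiple of k
    pads += [0, extra]
x = F.pad(x, pads)
print(x.shape)                      # torch.Size([1, 192, 224, 160])

windows = x.unfold(1, k, k).unfold(2, k, k).unfold(3, k, k)
print(windows.shape)                # torch.Size([1, 6, 7, 5, 32, 32, 32])
```

After padding, every voxel of the original volume lands in exactly one window; cropping back to 172x220x156 after reassembly discards the padded border.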
    【讨论】:

      【Solution 3】:

      To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only process 4D tensors, i.e. 2D images, but we can use them to process one dimension at a time.

      A few notes:

      1. This approach requires pytorch's fold/unfold methods; unfortunately I have not found similar methods in the TF api.

      2. We can extract patches in two ways, and their outputs are the same. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions. The latter uses torch.Tensor.unfold() and has fewer lines of code. (The output is the same, except it cannot use dilation.)

      3. The extract_patches_Xd and combine_patches_Xd methods are inverse methods, and the combiner reverses the steps of the extractor step by step.

      4. The lines of code are followed by comments stating the dimensions, such as (B, C, D, H, W). The following are used:

        1. B: batch size
        2. C: channels
        3. D: depth dimension
        4. H: height dimension
        5. W: width dimension
        6. x_dim_in: in the extraction method, this is the number of input pixels in dimension x. In the combining method, this is the number of sliding windows in dimension x.
        7. x_dim_out: in the extraction method, this is the number of sliding windows in dimension x. In the combining method, this is the number of output pixels in dimension x.
      5. I have a public notebook to try out the code.

      6. The get_dim_blocks() method is the function given on the pytorch docs website to compute the output shape of a convolutional layer.

      7. Note that if you have overlapping patches and combine them, the overlapping elements will be summed. If you want to get the initial input back, there is a way:

        1. Create a tensor of ones with the same size as the patches using torch.ones_like(patches_tensor).
        2. Combine the patches of ones into a full image with the same output shape. (This creates a counter for the overlapping elements.)
        3. Divide the combined image by the combined ones; this should reverse any double summation of elements.

      (3D): We need to use fold and unfold twice each. We first apply fold to the D dimension and leave W and H untouched by setting the kernel to 1, the padding to 0, the stride to 1 and the dilation to 1. After that, we reshape the tensor and fold over the H and W dimensions. The unfolding happens in reverse, starting with H and W, then D.
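Editor's note: the divide-by-ones recipe in note 7 can be illustrated with a minimal, self-contained 2D sketch (my own illustration using torch.nn.functional.unfold/fold directly, not this answer's 3D code): fold sums overlapping patches, so dividing by a folded tensor of ones recovers the original values.

```python
import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
patches = F.unfold(x, kernel_size=2, stride=1)   # (1, 4, 9): 9 overlapping 2x2 patches
summed = F.fold(patches, output_size=(4, 4), kernel_size=2, stride=1)  # overlaps summed
ones = F.fold(torch.ones_like(patches), output_size=(4, 4), kernel_size=2, stride=1)
recovered = summed / ones                        # per-pixel counts undo the summation
print(torch.equal(recovered, x))                 # True
```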
      import torch

      def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
          if isinstance(kernel_size, int):
              kernel_size = (kernel_size, kernel_size, kernel_size)
          if isinstance(padding, int):
              padding = (padding, padding, padding, padding, padding, padding)
          if isinstance(stride, int):
              stride = (stride, stride, stride)
      
          channels = x.shape[1]
      
          x = torch.nn.functional.pad(x, padding)
          # (B, C, D, H, W)
          x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
          # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
          x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
          # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
          return x
      
      def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
          if isinstance(kernel_size, int):
              kernel_size = (kernel_size, kernel_size, kernel_size)
          if isinstance(padding, int):
              padding = (padding, padding, padding)
          if isinstance(stride, int):
              stride = (stride, stride, stride)
          if isinstance(dilation, int):
              dilation = (dilation, dilation, dilation)
      
          def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
              dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
              return dim_out
      
          channels = x.shape[1]
      
          d_dim_in = x.shape[2]
          h_dim_in = x.shape[3]
          w_dim_in = x.shape[4]
          d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
          h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
          w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
          # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
          
          # (B, C, D, H, W)
          x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)                                                     
          # (B, C, D, H * W)
      
          x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))                   
          # (B, C * kernel_size[0], d_dim_out * H * W)
      
          x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)                                   
          # (B, C * kernel_size[0] * d_dim_out, H, W)
      
          x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))        
          # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out, w_dim_out)
      
          x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  
          # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  
      
          x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
          # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
      
          x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
          # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
      
          return x
      
      
      
      def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
          if isinstance(kernel_size, int):
              kernel_size = (kernel_size, kernel_size, kernel_size)
          if isinstance(padding, int):
              padding = (padding, padding, padding)
          if isinstance(stride, int):
              stride = (stride, stride, stride)
          if isinstance(dilation, int):
              dilation = (dilation, dilation, dilation)
      
          def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
              dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
              return dim_out
      
          channels = x.shape[1]
          d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
          d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
          h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
          w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
          # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
      
          x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
          # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
      
          x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
          # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)
      
          x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
          # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
      
          x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
          # (B, C * kernel_size[0] * d_dim_in, H, W)
      
          x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
          # (B, C * kernel_size[0], d_dim_in * H * W)
      
          x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
          # (B, C, D, H * W)
          
          x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
          # (B, C, D, H, W)
      
          return x
      
      a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
      print(a.shape)
      print(a)
      b = extract_patches_3d(a, 2, padding=1, stride=1)
      bs = extract_patches_3ds(a, 2, padding=1, stride=1)
      print(b.shape)
      print(b)
      c = combine_patches_3d(b, 2, (2,2,2,4,4), padding=1, stride=1)
      print(c.shape)
      print(c)
      ones = torch.ones_like(b)
      ones = combine_patches_3d(ones, 2, (2,2,2,4,4), padding=1, stride=1)
      print(torch.all(a==c))
      print(c.shape, ones.shape)
      d = c / ones
      print(d)
      print(torch.all(a==d))
      

      Output (3D):

      torch.Size([2, 2, 2, 4, 4])
      tensor([[[[[  1.,   2.,   3.,   4.],
                 [  5.,   6.,   7.,   8.],
                 [  9.,  10.,  11.,  12.],
                 [ 13.,  14.,  15.,  16.]],
      
                [[ 17.,  18.,  19.,  20.],
                 [ 21.,  22.,  23.,  24.],
                 [ 25.,  26.,  27.,  28.],
                 [ 29.,  30.,  31.,  32.]]],
      
      
               [[[ 33.,  34.,  35.,  36.],
                 [ 37.,  38.,  39.,  40.],
                 [ 41.,  42.,  43.,  44.],
                 [ 45.,  46.,  47.,  48.]],
      
                [[ 49.,  50.,  51.,  52.],
                 [ 53.,  54.,  55.,  56.],
                 [ 57.,  58.,  59.,  60.],
                 [ 61.,  62.,  63.,  64.]]]],
      
      
      
              [[[[ 65.,  66.,  67.,  68.],
                 [ 69.,  70.,  71.,  72.],
                 [ 73.,  74.,  75.,  76.],
                 [ 77.,  78.,  79.,  80.]],
      
                [[ 81.,  82.,  83.,  84.],
                 [ 85.,  86.,  87.,  88.],
                 [ 89.,  90.,  91.,  92.],
                 [ 93.,  94.,  95.,  96.]]],
      
      
               [[[ 97.,  98.,  99., 100.],
                 [101., 102., 103., 104.],
                 [105., 106., 107., 108.],
                 [109., 110., 111., 112.]],
      
                [[113., 114., 115., 116.],
                 [117., 118., 119., 120.],
                 [121., 122., 123., 124.],
                 [125., 126., 127., 128.]]]]])
      torch.Size([150, 2, 2, 2, 2])
      tensor([[[[[  0.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   1.]]],
      
      
               [[[  0.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  1.,   2.]]]],
      
      
      
              [[[[  0.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  2.,   3.]]],
      
      
               [[[  0.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  3.,   4.]]]],
      
      
      
              [[[[  0.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  4.,   0.]]],
      
      
               [[[  0.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   1.],
                 [  0.,   5.]]]],
      
      
      
              ...,
      
      
      
              [[[[124.,   0.],
                 [128.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   0.]]],
      
      
               [[[  0., 125.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   0.]]]],
      
      
      
              [[[[125., 126.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   0.]]],
      
      
               [[[126., 127.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   0.]]]],
      
      
      
              [[[[127., 128.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   0.]]],
      
      
               [[[128.,   0.],
                 [  0.,   0.]],
      
                [[  0.,   0.],
                 [  0.,   0.]]]]])
      torch.Size([2, 2, 2, 4, 4])
      tensor([[[[[   8.,   16.,   24.,   32.],
                 [  40.,   48.,   56.,   64.],
                 [  72.,   80.,   88.,   96.],
                 [ 104.,  112.,  120.,  128.]],
      
                [[ 136.,  144.,  152.,  160.],
                 [ 168.,  176.,  184.,  192.],
                 [ 200.,  208.,  216.,  224.],
                 [ 232.,  240.,  248.,  256.]]],
      
      
               [[[ 264.,  272.,  280.,  288.],
                 [ 296.,  304.,  312.,  320.],
                 [ 328.,  336.,  344.,  352.],
                 [ 360.,  368.,  376.,  384.]],
      
                [[ 392.,  400.,  408.,  416.],
                 [ 424.,  432.,  440.,  448.],
                 [ 456.,  464.,  472.,  480.],
                 [ 488.,  496.,  504.,  512.]]]],
      
      
      
              [[[[ 520.,  528.,  536.,  544.],
                 [ 552.,  560.,  568.,  576.],
                 [ 584.,  592.,  600.,  608.],
                 [ 616.,  624.,  632.,  640.]],
      
                [[ 648.,  656.,  664.,  672.],
                 [ 680.,  688.,  696.,  704.],
                 [ 712.,  720.,  728.,  736.],
                 [ 744.,  752.,  760.,  768.]]],
      
      
               [[[ 776.,  784.,  792.,  800.],
                 [ 808.,  816.,  824.,  832.],
                 [ 840.,  848.,  856.,  864.],
                 [ 872.,  880.,  888.,  896.]],
      
                [[ 904.,  912.,  920.,  928.],
                 [ 936.,  944.,  952.,  960.],
                 [ 968.,  976.,  984.,  992.],
                 [1000., 1008., 1016., 1024.]]]]])
      tensor(False)
      torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
      tensor([[[[[  1.,   2.,   3.,   4.],
                 [  5.,   6.,   7.,   8.],
                 [  9.,  10.,  11.,  12.],
                 [ 13.,  14.,  15.,  16.]],
      
                [[ 17.,  18.,  19.,  20.],
                 [ 21.,  22.,  23.,  24.],
                 [ 25.,  26.,  27.,  28.],
                 [ 29.,  30.,  31.,  32.]]],
      
      
               [[[ 33.,  34.,  35.,  36.],
                 [ 37.,  38.,  39.,  40.],
                 [ 41.,  42.,  43.,  44.],
                 [ 45.,  46.,  47.,  48.]],
      
                [[ 49.,  50.,  51.,  52.],
                 [ 53.,  54.,  55.,  56.],
                 [ 57.,  58.,  59.,  60.],
                 [ 61.,  62.,  63.,  64.]]]],
      
      
      
              [[[[ 65.,  66.,  67.,  68.],
                 [ 69.,  70.,  71.,  72.],
                 [ 73.,  74.,  75.,  76.],
                 [ 77.,  78.,  79.,  80.]],
      
                [[ 81.,  82.,  83.,  84.],
                 [ 85.,  86.,  87.,  88.],
                 [ 89.,  90.,  91.,  92.],
                 [ 93.,  94.,  95.,  96.]]],
      
      
               [[[ 97.,  98.,  99., 100.],
                 [101., 102., 103., 104.],
                 [105., 106., 107., 108.],
                 [109., 110., 111., 112.]],
      
                [[113., 114., 115., 116.],
                 [117., 118., 119., 120.],
                 [121., 122., 123., 124.],
                 [125., 126., 127., 128.]]]]])
      tensor(True)
      
      

      【Discussion】:
