要提取(可重叠的)补丁并重建输入形状,可以使用 torch.nn.functional.unfold 及其逆运算 torch.nn.functional.fold。这两个函数只能处理 4D 张量(即 2D 图像批次),但我们可以每次只对部分维度调用它们,从而处理更高维的数据。
几点说明:
-
这种方式需要 pytorch 的折叠/展开方法,遗憾的是我还没有在 TF api 中找到类似的方法。
-
我们可以通过两种方式提取补丁,它们的输出是相同的。这些方法称为 extract_patches_Xd 和 extract_patches_Xds,其中 X 是维数(本文中为 3,即 extract_patches_3d 和 extract_patches_3ds)。后者使用 torch.Tensor.unfold(),代码行数更少。(两者输出相同,只是后者不支持膨胀 dilation。)
-
extract_patches_Xd 和 combine_patches_Xd 互为逆方法:组合方法按相反的顺序逐步撤销提取方法的每一步。
-
代码行后面是说明维度的注释,例如 (B, C, D, H, W)。使用以下内容:
-
B: 批量大小
-
C: 通道数
-
D: 深度维度
-
H: 高度维度
-
W: 宽度维度
-
x_dim_in:在提取方法中,这是维度x中的输入像素数。在组合方法中,这是维度x的滑动窗口个数。
-
x_dim_out:在提取方法中,这是维度x的滑动窗口数。在组合方法中,这是维度x的输出像素数。
-
我准备了一个公开的 notebook,可以用来试运行这些代码。
-
get_dim_blocks() 方法实现的是 PyTorch 文档网站上给出的公式,用于计算卷积层的输出形状(即每个维度上滑动窗口的个数)。
-
请注意,如果补丁之间有重叠,将它们组合起来时重叠的元素会被求和。如果想重新得到原始输入,可以按以下步骤操作:
- 使用
torch.ones_like(patches_tensor) 创建一个与补丁张量形状相同的全 1 张量。
- 用相同的参数和输出形状把这个全 1 张量组合成完整图像。(这会得到一个计数器,记录每个元素被多少个补丁覆盖。)
- 将组合后的补丁图像除以组合后的计数器图像,这样就抵消了重叠元素的重复求和。
(3D):
我们需要分两步使用 unfold(组合时对应地分两步使用 fold)。
提取时,首先对 D 维度应用 unfold,并把 H/W 方向的核大小设为 1、填充设为 0、步幅设为 1、膨胀设为 1,从而保持 H 和 W 不变。之后调整张量视图,再对 H 和 W 维度应用 unfold。组合(fold)则按相反顺序进行:先处理 H 和 W,再处理 D。
import torch


def _to_triple(value):
    """Return ``value`` as a tuple, expanding a bare int to ``(value, value, value)``."""
    if isinstance(value, int):
        return (value, value, value)
    return tuple(value)


def _get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    """Return the number of sliding-window positions along one dimension.

    This is the output-size formula given in the PyTorch convolution docs.
    """
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1


def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    """Extract (possibly overlapping) 3-D patches using ``Tensor.unfold``.

    Produces the same result as ``extract_patches_3d`` with ``dilation=1``
    (``Tensor.unfold`` has no dilation support).

    Args:
        x: tensor of shape (B, C, D, H, W).
        kernel_size: int or (kD, kH, kW).
        padding: int, (pD, pH, pW), or a tuple already in
            ``torch.nn.functional.pad`` order (last dimension first).
        stride: int or (sD, sH, sW).

    Returns:
        Tensor of shape (B * d_out * h_out * w_out, C, kD, kH, kW): one row
        per window position, holding that window for every channel.
    """
    kernel_size = _to_triple(kernel_size)
    stride = _to_triple(stride)
    if isinstance(padding, int):
        padding = (padding,) * 6
    elif len(padding) == 3:
        # (pD, pH, pW) -> F.pad order, which starts from the last dimension.
        pad_d, pad_h, pad_w = padding
        padding = (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d)
    channels = x.shape[1]
    x = torch.nn.functional.pad(x, tuple(padding))
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kD, kH, kW)
    # Move C next to the kernel axes before flattening; otherwise the
    # flattened "channel" axis would hold neighbouring windows instead.
    x = x.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # (B, d_dim_out, h_dim_out, w_dim_out, C, kD, kH, kW)
    return x.reshape(-1, channels, *kernel_size)
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kD, kH, kW)


def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    """Extract (possibly overlapping) 3-D patches using two ``F.unfold`` passes.

    ``F.unfold`` only accepts 4-D input, so D is unfolded first (with H and W
    flattened into one untouched axis), then H and W are unfolded together.

    Args:
        x: tensor of shape (B, C, D, H, W); need not be contiguous.
        kernel_size, padding, stride, dilation: int or (d, h, w) tuples.

    Returns:
        Tensor of shape (B * d_out * h_out * w_out, C, kD, kH, kW): one row
        per window position, holding that window for every channel.
    """
    kernel_size = _to_triple(kernel_size)
    padding = _to_triple(padding)
    stride = _to_triple(stride)
    dilation = _to_triple(dilation)

    channels = x.shape[1]
    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = _get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = _get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = _get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])

    # (B, C, D, H, W); reshape (not view) so non-contiguous inputs work.
    x = x.reshape(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W) -- kernel/padding/stride/dilation of 1/0/1/1 leave H*W untouched.
    x = torch.nn.functional.unfold(
        x,
        kernel_size=(kernel_size[0], 1),
        padding=(padding[0], 0),
        stride=(stride[0], 1),
        dilation=(dilation[0], 1),
    )
    # (B, C * kD, d_dim_out * H * W)
    x = x.reshape(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kD * d_dim_out, H, W)
    x = torch.nn.functional.unfold(
        x,
        kernel_size=kernel_size[1:],
        padding=padding[1:],
        stride=stride[1:],
        dilation=dilation[1:],
    )
    # (B, C * kD * d_dim_out * kH * kW, h_dim_out * w_dim_out)
    x = x.reshape(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kD, d_dim_out, kH, kW, h_dim_out, w_dim_out)
    # Window positions first, then C, so each output row is one window
    # containing all channels.
    x = x.permute(0, 3, 6, 7, 1, 2, 4, 5)
    # (B, d_dim_out, h_dim_out, w_dim_out, C, kD, kH, kW)
    return x.reshape(-1, channels, *kernel_size)
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kD, kH, kW)


def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    """Inverse of ``extract_patches_3d``: fold patches back into a volume.

    Overlapping elements are summed, not averaged. Dividing by the result of
    combining ``torch.ones_like(x)`` recovers the original values.

    Args:
        x: patches of shape (B * d_out * h_out * w_out, C, kD, kH, kW), laid
            out as returned by ``extract_patches_3d``.
        kernel_size, padding, stride, dilation: int or (d, h, w) tuples; must
            match the values used for extraction.
        output_shape: shape of the volume to rebuild; only ``[2:]`` (D, H, W)
            is read.

    Returns:
        Tensor of shape (B, C, D, H, W).
    """
    kernel_size = _to_triple(kernel_size)
    padding = _to_triple(padding)
    stride = _to_triple(stride)
    dilation = _to_triple(dilation)

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    # Here "dim_in" is the number of windows per dimension and "dim_out"
    # the number of output pixels.
    d_dim_in = _get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = _get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = _get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])

    x = x.reshape(-1, d_dim_in, h_dim_in, w_dim_in, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, d_dim_in, h_dim_in, w_dim_in, C, kD, kH, kW)
    x = x.permute(0, 4, 5, 1, 6, 7, 2, 3)
    # (B, C, kD, d_dim_in, kH, kW, h_dim_in, w_dim_in)
    x = x.reshape(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kD * d_dim_in * kH * kW, h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(
        x,
        output_size=(h_dim_out, w_dim_out),
        kernel_size=kernel_size[1:],
        padding=padding[1:],
        stride=stride[1:],
        dilation=dilation[1:],
    )
    # (B, C * kD * d_dim_in, H, W)
    x = x.reshape(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kD, d_dim_in * H * W)
    x = torch.nn.functional.fold(
        x,
        output_size=(d_dim_out, h_dim_out * w_dim_out),
        kernel_size=(kernel_size[0], 1),
        padding=(padding[0], 0),
        stride=(stride[0], 1),
        dilation=(dilation[0], 1),
    )
    # (B, C, D, H * W)
    return x.reshape(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)
# Demo: extract overlapping 2x2x2 patches from a (B=2, C=2, D=2, H=4, W=4)
# volume, recombine them, and undo the summation of overlapping elements.
a = torch.arange(1, 129, dtype=torch.float).view(2, 2, 2, 4, 4)
print(a.shape)
print(a)
b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)
print(b.shape)
print(b)
# output_shape is the third parameter of combine_patches_3d; passing it as the
# second positional argument would bind it to kernel_size and raise TypeError.
c = combine_patches_3d(b, kernel_size=2, output_shape=a.shape, padding=1, stride=1)
print(c.shape)
print(c)
# Combining a tensor of ones counts how many patches cover each element.
ones = torch.ones_like(b)
ones = combine_patches_3d(ones, kernel_size=2, output_shape=a.shape, padding=1, stride=1)
print(torch.all(a == c))
print(c.shape, ones.shape)
d = c / ones
print(d)
print(torch.all(a == d))
输出(3D)
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 1.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 1., 2.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 2., 3.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 3., 4.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 4., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 1.],
[ 0., 5.]]]],
...,
[[[[124., 0.],
[128., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 125.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[125., 126.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[126., 127.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[127., 128.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[128., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 8., 16., 24., 32.],
[ 40., 48., 56., 64.],
[ 72., 80., 88., 96.],
[ 104., 112., 120., 128.]],
[[ 136., 144., 152., 160.],
[ 168., 176., 184., 192.],
[ 200., 208., 216., 224.],
[ 232., 240., 248., 256.]]],
[[[ 264., 272., 280., 288.],
[ 296., 304., 312., 320.],
[ 328., 336., 344., 352.],
[ 360., 368., 376., 384.]],
[[ 392., 400., 408., 416.],
[ 424., 432., 440., 448.],
[ 456., 464., 472., 480.],
[ 488., 496., 504., 512.]]]],
[[[[ 520., 528., 536., 544.],
[ 552., 560., 568., 576.],
[ 584., 592., 600., 608.],
[ 616., 624., 632., 640.]],
[[ 648., 656., 664., 672.],
[ 680., 688., 696., 704.],
[ 712., 720., 728., 736.],
[ 744., 752., 760., 768.]]],
[[[ 776., 784., 792., 800.],
[ 808., 816., 824., 832.],
[ 840., 848., 856., 864.],
[ 872., 880., 888., 896.]],
[[ 904., 912., 920., 928.],
[ 936., 944., 952., 960.],
[ 968., 976., 984., 992.],
[1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
tensor(True)