【问题标题】:How do I merge 2D Convolutions in PyTorch?如何在 PyTorch 中合并 2D 卷积?
【发布时间】:2019-10-12 19:38:48
【问题描述】:

从线性代数我们知道线性算子是结合的。

在深度学习领域,这个概念被用来证明在 NN 层之间引入非线性是合理的,以防止一种俗称linear lasagna,(reference) 的现象。

在信号处理中,这也导致了一个众所周知的技巧来优化内存和/或运行时要求 (reference)。

所以从不同的角度来看,合并卷积是一个非常有用的工具。如何用 PyTorch 实现?

【问题讨论】:

  • “从线性代数我们知道线性算子是交换的和结合的。”这不是真的,您只能假设关联性,交换性在特殊情况(同时对角化)之外很少见。
  • 糟糕!你是对的,我的意思是线性系统,而不是运算符。我会更正并澄清。 “线性”的概念不同,这令人困惑,所以我也会澄清一下。感谢您的评论!
  • 其实我刚删了。作为记录,这是我的参考:dspguide.com/ch5/5.htm。在那里,交换性被列为“系统线性”的属性。但是在这个问题中,那种术语只是令人困惑/误导并且不需要。再次感谢!

标签: pytorch linear-algebra convolution


【解决方案1】:

如果我们有y = x * a * b(其中* 表示卷积,a, b 是你的内核),我们可以定义c = a * b 使得y = x * c = x * a * b 如下:

import torch

def merge_conv_kernels(k1, k2):
    """
    :input k1: A tensor of shape ``(out1, in1, s1, s1)``
    :input k1: A tensor of shape ``(out2, in2, s2, s2)``
    :returns: A tensor of shape ``(out2, in1, s1+s2-1, s1+s2-1)``
      so that convolving with it equals convolving with k1 and
      then with k2.
    """
    padding = k2.shape[-1] - 1
    # Flip because this is actually correlation, and permute to adapt to BHCW
    k3 = torch.conv2d(k1.permute(1, 0, 2, 3), k2.flip(-1, -2),
                      padding=padding).permute(1, 0, 2, 3)
    return k3

为了说明等价性,本例将两个分别具有 900 和 5000 个参数的内核组合成一个具有 28 个参数的等效内核:

# Create 2 conv. kernels
out1, in1, s1 = (100, 1, 3)
out2, in2, s2 = (2, 100, 5)
kernel1 = torch.rand(out1, in1, s1, s1, dtype=torch.float64)
kernel2 = torch.rand(out2, in2, s2, s2, dtype=torch.float64)

# propagate a random tensor through them. Note that padding
# corresponds to the "full" mathematical operation (s-1)
b, c, h, w = 1, 1, 6, 6
x = torch.rand(b, c, h, w, dtype=torch.float64) * 10
c1 = torch.conv2d(x, kernel1, padding=s1 - 1)
c2 = torch.conv2d(c1, kernel2, padding=s2 - 1)

# check that the collapsed conv2d is same as c2:
kernel3 = merge_conv_kernels(kernel1, kernel2)
c3 = torch.conv2d(x, kernel3, padding=kernel3.shape[-1] - 1)
print(kernel3.shape)
print((c2 - c3).abs().sum() < 1e-5)

注意:等价是假设我们有无限的数值分辨率。我认为有关于堆叠许多低分辨率浮点线性运算并表明网络从数值误差中获利的研究,但我找不到它。任何参考表示赞赏!

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-09-20
    • 2019-06-26
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多