【发布时间】:2021-01-24 19:30:58
【问题描述】:
PyTorch 函数torch.nn.functional.interpolate 包含多种上采样模式,例如:nearest、linear、bilinear、bicubic、trilinear、area。
area 上采样模式用于什么?
【问题讨论】:
-
欢迎来到 Stack Overflow!
标签: pytorch interpolation
PyTorch 函数torch.nn.functional.interpolate 包含多种上采样模式,例如:nearest、linear、bilinear、bicubic、trilinear、area。
area 上采样模式用于什么?
【问题讨论】:
标签: pytorch interpolation
查看源代码,area 插值相当于通过adaptive average pooling 调整张量大小。关于自适应平均池化的解释可以参考this question。因此area插值比上采样更适用于下采样。
【讨论】:
正如 jodag 所说,它正在使用自适应平均池来调整大小。虽然链接上的答案旨在解释什么是自适应平均池,但我觉得解释有点模糊。
TL;DR torch.nn.functional.interpolate 的 area 模式可能是想对图像进行下采样时最直观的方法之一。
您可以将其视为对原始图像应用平均低通滤波器(LPF),然后进行采样。在采样前应用 LPF 是为了防止下采样图像中出现潜在的aliasing。 锯齿会导致缩小图像中出现莫尔条纹。
它可能被称为“面积”,因为它(大致)在平均输入像素时保留了输入和输出形状之间的面积比。更具体地说,输出图像中的每个像素将是输入图像中相应区域的平均值,其中该区域的1/area 大致是输出图像面积与输入图像面积之比。
此外,带有mode = 'area' 的interpolate 函数调用源函数adaptie_avg_pool2d(用C++ 实现),它为输出张量中的每个像素分配输入计算区域内所有像素强度的平均值。该区域是按像素计算的,并且不同像素的大小可能会有所不同。它的计算方法是将输出像素的高度和宽度乘以输入和输出(按此顺序)高度和宽度(分别)之间的比率,然后取一次floor(用于区域的起始索引)和一次结果值的ceil(用于区域的结束索引)。
下面是对nn.AdaptiveAvgPool2d发生的事情的深入分析:
首先,如上所述,您可以在此处找到自适应平均池(C++ 中)的源代码:source
看看发生魔法的函数(或者至少是单帧 CPU 上的魔法)static void adaptive_avg_pool2d_single_out_frame,我们有 5 个嵌套循环,依次运行在通道维度、宽度、高度和主体内第三个循环发生了奇迹:
首先计算输入图像中用于计算当前像素值的区域(回想一下,我们有宽度和高度循环来遍历输出中的所有像素)。 这是怎么做到的?
使用简单的开始和结束索引计算高度和宽度,如下所示:floor((input_height/output_height) * current_output_pixel_height) 用于开始,ceil((input_height/output_height) * (current_output_pixel_height+1)) 用于宽度。
然后,所做的只是简单地平均该区域和当前通道中所有像素的强度,并将结果放入当前输出像素中。
我编写了一个简单的 Python sn-p,它以相同的方式(循环、幼稚)执行相同的操作并产生相同的结果。它采用张量a 并使用自适应平均池以两种方式调整a 的大小以塑造output_shape - 一次使用内置的nn.AdaptiveAvgPool2d,一次使用我在C++ 中将源函数翻译成Python:@987654338 @。内置函数的结果保存到b,我的翻译保存到b_hat。您可以看到结果是等效的(您可以进一步使用空间形状并验证这一点):
import torch
from math import floor, ceil
from torch import nn
a = torch.randn(1, 3, 15, 17)
out_shape = (10, 11)
b = nn.AdaptiveAvgPool2d(out_shape)(a)
b_hat = torch.zeros(b.shape)
for d in range(a.shape[1]):
for w in range(b_hat.shape[3]):
for h in range(b_hat.shape[2]):
startW = floor(w * a.shape[3] / out_shape[1])
endW = ceil((w + 1) * a.shape[3] / out_shape[1])
startH = floor(h * a.shape[2] / out_shape[0])
endH = ceil((h + 1) * a.shape[2] / out_shape[0])
b_hat[0, d, h, w] = torch.mean(a[0, d, startH: endH, startW: endW])
'''
Prints Mean Squared Error = 0 (or a very small number, due to precision error)
as both outputs are the same, proof of output equivalence:
'''
print(nn.MSELoss()(b_hat, b))
【讨论】: