【发布时间】:2021-09-14 20:31:38
【问题描述】:
假设我们正在使用CIFAR-10 dataset,并且我们想要应用一些数据增强并另外规范化张量。这是一些可重现的代码
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
trafo = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomCrop(size = (32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean = (0.0, 0.0, 0.0), std = (1.0, 1.0, 1.0))]
)
cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo, target_transform = None, download = True)
到目前为止,我选择的规范化对张量没有任何作用,因为我将 mean 和 std 分别放在 0 和 1 上。根据torchvision.transforms.Normalize 的文档,提供的均值和标准差适用于输入的每个通道。然而,问题是我无法计算每个通道的平均值,因为一些随机翻转和裁剪平均值。因此,我的想法大致如下
trafo_1 = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomCrop(size = (32, 32)),
transforms.ToTensor()
)
cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo_1, target_transform = None, download = True)
现在我可以计算输入的每个通道的平均值,然后我想再次对张量进行归一化。 但是,我不能简单地使用 transforms.Normalize(),因为 cifar10_full 不再是原始数据集,而是我该如何继续?(一种解决方案是简单地修复随机生成器的种子,即使用torch.manual_seed(0),但我现在想避免这种情况......)
【问题讨论】:
标签: deep-learning neural-network pytorch