【问题标题】:PyTorch: How to normalize a tensor when the image is cropped randomly?PyTorch:如何在随机裁剪图像时对张量进行归一化?
【发布时间】:2021-09-14 20:31:38
【问题描述】:

假设我们正在使用CIFAR-10 dataset,并且我们想要应用一些数据增强并另外规范化张量。这是一些可重现的代码

from torchvision import transforms, datasets
import matplotlib.pyplot as plt
trafo = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"), 
                            transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomCrop(size = (32, 32)), 
                            transforms.ToTensor(), 
                            transforms.Normalize(mean = (0.0, 0.0, 0.0), std = (1.0, 1.0, 1.0))]
                          )

cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo, target_transform = None, download = True)

到目前为止,我选择的规范化对张量没有任何作用,因为我将 meanstd 分别放在 01 上。根据torchvision.transforms.Normalize 的文档,提供的均值和标准差适用于输入的每个通道。然而,问题是我无法计算每个通道的平均值,因为一些随机翻转和裁剪平均值。因此,我的想法大致如下

trafo_1 = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"), 
                            transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomCrop(size = (32, 32)), 
                            transforms.ToTensor() 
                          )

cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo_1, target_transform = None, download = True)

现在我可以计算输入的每个通道的平均值,然后我想再次对张量进行归一化。 但是,我不能简单地使用 transforms.Normalize(),因为 cifar10_full 不再是原始数据集,而是我该如何继续?(一种解决方案是简单地修复随机生成器的种子,即使用torch.manual_seed(0),但我现在想避免这种情况......)

【问题讨论】:

    标签: deep-learning neural-network pytorch


    【解决方案1】:

    平均值和标准差不是针对每个张量,而是来自整个数据集。你想要做什么并不重要,你只需要一个足够好的尺度来表示整个数据,没有确切的平均值或标准你会得到,这些都是随机操作,只需使用平均值和标准从实际数据来看,这几乎是标准的。

    首先,尝试计算数据集的均值和标准差(尝试随机抽样),并将其用于归一化。

    # Calculate the mean, std of the complete dataset
    import glob
    import cv2
    import numpy as np 
    import tqdm
    import random
    
    # calculating 3 channel mean and std for image dataset
    
    means = np.array([0, 0, 0], dtype=np.float32)
    stds = np.array([0, 0, 0], dtype=np.float32)
    total_images = 0
    randomly_sample = 5000
    for f in tqdm.tqdm(random.sample(glob.glob("dataset_path/**.jpg", recursive = True), randomly_sample)):
        img = cv2.imread(f)
        means += img.mean(axis=(0,1))
        stds += img.std(axis=(0,1))
        total_images += 1
    means = means / (total_images * 255.)
    stds = stds / (total_images * 255.)
    print("Total images: ", total_images)
    print("Means: ", means)
    print("Stds: ", stds)
    

    只是一个简单的场景,您是否认为在实际测试或推理中您的图像也会以这种方式增强,可能不会,您将拥有与干净版本数据的均值和标准非常匹配的干净图像,所以除非您想应用 TTA,否则计算均值和标准差是没有用的(您可以随机抽取少量样本)。

    如果您也想应用 TTA,那么您可以继续对图像进行一些增强,进行随机抽样并获取这些图像的均值和标准差。

    【讨论】:

    • 嗨。所以你说的(也显示在代码中)是简单地省略填充等来找出平均值和标准差的值,对吧?但是按照我建议的方式去做不是更有意义吗?因为我输入 NN 的张量是那些已经被填充的,等等。
    • 你想要做什么并不重要,你只需要一个足够好的尺度来表示整个数据,你不会得到精确的平均值或标准,这些都是随机的操作,只需使用来自实际数据的均值和标准差,这几乎是标准。
    • 只是一个简单的场景,您是否认为在实际测试或推断中您的图像也会以这种方式增强,可能不会,您将获得与干净版本的均值和标准非常匹配的干净图像的数据,所以计算均值和标准差是没有用的(你可以随机抽取几个样本),除非你想应用 TTA。
    • 好的,我明白你的意思,但是“TTA”是什么意思?我以前从未听说过。
    • 我知道,这就是为什么我猜你不需要增广的均值和标准差,TTA 指的是测试时间增广。
    猜你喜欢
    • 1970-01-01
    • 2012-11-02
    • 2018-06-22
    • 1970-01-01
    • 2018-05-04
    • 1970-01-01
    • 2019-07-10
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多