【问题标题】:How to calculate the mean and the std of cifar10 data如何计算 cifar10 数据的均值和标准差
【发布时间】:2021-06-15 02:01:13
【问题描述】:

Pytorch 使用以下值作为 cifar10 数据的平均值和标准: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

我需要理解计算它背后的概念,因为这个数据是 3 通道图像,我不明白什么是加和除以什么等等。 另外,如果有人可以分享计算平均值和标准差的代码,将非常感激。

【问题讨论】:

    标签: pytorch torchvision


    【解决方案1】:

    0.5 值只是三个通道 (r,g,b) 上 cifar10 均值和标准值的近似值。 cifar10 训练集的精确值是

    • 意思是:0.49139968, 0.48215827 ,0.44653124
    • 标准:0.24703233 0.24348505 0.26158768

    您可以使用以下脚本计算这些:

    import torch
    import numpy
    import torchvision.datasets as datasets
    from torchvision import transforms
    
    cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
    
    imgs = [item[0] for item in cifar_trainset] # item[0] and item[1] are image and its label
    imgs = torch.stack(imgs, dim=0).numpy()
    
    # calculate mean over each channel (r,g,b)
    mean_r = imgs[:,0,:,:].mean()
    mean_g = imgs[:,1,:,:].mean()
    mean_b = imgs[:,2,:,:].mean()
    print(mean_r,mean_g,mean_b)
    
    # calculate std over each channel (r,g,b)
    std_r = imgs[:,0,:,:].std()
    std_g = imgs[:,1,:,:].std()
    std_b = imgs[:,2,:,:].std()
    print(std_r,std_g,std_b)
    

    此外,您可能会发现 herehere 的均值和标准值相同

    【讨论】:

      【解决方案2】:

      另一种方式

      from  torchvision import datasets
      
      cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True  )
      data = cifar_trainset.data / 255 # data is numpy array
      
      mean = data.mean(axis = (0,1,2)) 
      std = data.std(axis = (0,1,2))
      print(f"Mean : {mean}   STD: {std}") #Mean : [0.491 0.482 0.446]   STD: [0.247 0.243 0.261]
      

      【讨论】:

        猜你喜欢
        • 2021-10-09
        • 2014-03-21
        • 1970-01-01
        • 1970-01-01
        • 2021-04-08
        • 1970-01-01
        • 2021-06-03
        相关资源
        最近更新 更多