【问题标题】:Wrong value of standard deviation标准差的错误值
【发布时间】:2019-12-04 18:12:07
【问题描述】:

您好,我正在尝试评估数据集 MNIST 的标准差和平均值,但我得到了错误的标准差值。这是我的代码:

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F

loader = torch.utils.data.DataLoader(datasets.MNIST(
'../data', train=True, download=True, transform=transform1),
                     batch_size=32,
                     num_workers=0,
                     shuffle=False)

mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0) 
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(loader.dataset)
std /= len(loader.dataset)

print("The mean is ", mean)
print("The standard deviation is ", std)

我的问题如下,我得到的平均值是 0.1307,标准差是 0.3015 而不是 0.3081。我想我的代码中有错误,但我看不到在哪里。

你能帮帮我吗?

非常感谢!

【问题讨论】:

    标签: python python-3.x artificial-intelligence pytorch mnist


    【解决方案1】:

    这里的微小差异来自这样一个事实,即在您的代码中计算均值和标准差的方式不同,并且通常将它们用于归一化。

    在这里,您所做的是计算每个图像中所有像素的每个批次的平均值和标准差,然后取它们的平均值。您最终会得到 0.3015 的值。

    现在,如果您要计算整个数据集的均值和标准差,您将不会使用相同的均值,最终会找到 0.3081 的值。

    【讨论】:

      【解决方案2】:

      torch.std 使用批次均值作为计算的一部分,因此它与在整个数据集上使用torch.std 不同,因为这将使用不同的均值。我们可以使用下面的well known expression 进行方差来得到想要的结果

      Var(X) = E[X**2] - E[X]**2

      mean = 0.
      mean_square = 0.
      samples = 0
      for images, _ in loader:
          batch_samples = images.size(0) 
          images = images.view(batch_samples, images.size(1), -1)
          mean += images.mean(2).sum(0)
          mean_square += (images**2).mean(2).sum(0)
          samples += images.size(2) * images.size(0)
      
      mean /= len(loader.dataset)
      mean_square /= len(loader.dataset)
      
      # extra scale factor for unbias std estimate (it's effectively 1.0)
      scale = samples / (samples - 1)
      std = torch.sqrt((mean_square - mean**2) * scale)
      
      print("The mean is ", mean)
      print("The standard deviation is ", std)
      

      当然,在 torchvision MNIST 数据集的特殊情况下,您可以直接计算均值和标准差...

      mean = torch.mean(loader.dataset.data.float() / 255.0)
      std = torch.std(loader.dataset.data.float() / 255.0)
      

      【讨论】: