【发布时间】:2019-12-04 18:12:07
【问题描述】:
您好,我正在尝试评估数据集 MNIST 的标准差和平均值,但我得到了错误的标准差值。这是我的代码:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
loader = torch.utils.data.DataLoader(datasets.MNIST(
'../data', train=True, download=True, transform=transform1),
batch_size=32,
num_workers=0,
shuffle=False)
mean = 0.
std = 0.
for images, _ in loader:
batch_samples = images.size(0)
images = images.view(batch_samples, images.size(1), -1)
mean += images.mean(2).sum(0)
std += images.std(2).sum(0)
mean /= len(loader.dataset)
std /= len(loader.dataset)
print("The mean is ", mean)
print("The standard deviation is ", std)
我的问题如下,我得到的平均值是 0.1307,标准差是 0.3015 而不是 0.3081。我想我的代码中有错误,但我看不到在哪里。
你能帮帮我吗?
非常感谢!
【问题讨论】:
标签: python python-3.x artificial-intelligence pytorch mnist