【问题标题】:PyTorch - Convert CIFAR dataset to `TensorDataset`PyTorch - 将 CIFAR 数据集转换为“TensorDataset”
【发布时间】:2021-05-01 15:43:20
【问题描述】:

我在 CIFAR 数据集上训练 ResNet34。由于某种原因,我需要将数据集转换为TensorDataset。 我的解决方案基于此:https://stackoverflow.com/a/44475689/15072863 有一些差异(也许它们很关键,但我不明白为什么)。 看来我做的不对。

火车装载机:

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

train_ds = torchvision.datasets.CIFAR10('/files/', train=True, transform=transform_train, download=True)

xs, ys = [], []
for x, y in train_ds:
  xs.append(x)
  ys.append(y)

# 1) Standard Version
# cifar_train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)

# 2) TensorDataset version, seems to be incorrect
cifar_tensor_ds = TensorDataset(torch.stack(xs), torch.tensor(ys, dtype=torch.long))
cifar_train_loader = DataLoader(cifar_tensor_ds, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)

我认为这并不重要,但测试加载器的定义与往常一样:

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

cifar_test_loader = DataLoader(
  torchvision.datasets.CIFAR10('/files/', train=False, transform=transform_test, download=True),
  batch_size=batch_size_test, shuffle=False, num_workers=num_workers)

我知道我使用 TensorDataset 的方式有问题,因为;

  1. 使用TensorDataset,我实现了 100% 的训练准确率和 80% 的测试准确率
  2. 使用标准数据集,我实现了 99% 的训练准确度(从来没有 100%)和 90% 的测试准确度。

那么,我做错了什么?

P.S.:我的最终目标是根据类别将数据集分成 10 个数据集。有一个更好的方法吗?当然,我可以定义我的 DataSet 子类,但是手动拆分它并创建TensorDataset 似乎更简单。

【问题讨论】:

    标签: python neural-network pytorch dataset dataloader


    【解决方案1】:

    使用“标准”数据集时,每次加载图像时,都会对其应用随机变换(翻转 + 裁剪)。因此,几乎每个时代的每幅图像都是独一无二的,只能看到一次。所以你有 nb_epochs * len(dataset) 不同的输入。

    使用您的自定义数据集,您首先读取 CIFAR 数据集的所有图像(每个图像都经过随机变换),将它们全部存储,然后使用存储的张量作为您的训练输入。因此在每个时期,网络看到完全相同的输入

    由于网络已经能够通过随机变换达到很高的准确度,因此移除它会使其变得更加容易,从而进一步提高准确度

    哦,你绝对应该重新定义你自己的数据集子类。它甚至并不复杂,而且使用起来会容易得多。你只需要提取 10 个不同的数据集,要么通过手动移动文件夹中的图像,要么使用一些重新索引数组或类似的东西。不管怎样,你只需要做一次,所以没什么大不了的

    【讨论】:

    • 谢谢。我在帖子中指定了错误的数字:TensorDataset 的测试准确度较低。你的解释很有道理。
    猜你喜欢
    • 2019-08-30
    • 1970-01-01
    • 1970-01-01
    • 2019-07-20
    • 2019-12-09
    • 1970-01-01
    • 2021-02-09
    • 1970-01-01
    • 2015-11-13
    相关资源
    最近更新 更多