【问题标题】:Why torch does train using mini-batches?为什么 torch 使用小批量进行训练?
【发布时间】:2022-01-26 18:45:06
【问题描述】:

我目前正在努力了解如何通过pytorch 训练模型。虽然这我看到了一个非常有趣的功能:传递给训练数据 --- 是一个小批量。例如。

有一个来自官方pytorch的代码片段web-site

...

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}

...
...

for inputs, labels in dataloaders[phase]:
    inputs = inputs.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()

    with torch.set_grad_enabled(phase == 'train'):
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

...

根据这段代码,传递给模型的输入是一个小批量。

我已经绑定了一些关于此的信息,但没有成功。

但我真的很好奇,这是某种提升(并行运行等)还是必要的事情。那么,您是否愿意帮我弄清楚并告诉我,为什么会有一个 mini-batch 传递给 train 函数?

注意不会拒绝论文的链接[微笑:)]。

【问题讨论】:

  • 你很少训练整个数据集(由于计算限制),因此你训练的是批量数据
  • 那么模型如何在批量训练的同时学习处理一张图像来计算预测?它一般是如何运行的?
  • 使用小批量可以加快收敛速度​​。 This 可能会有所帮助。
  • 除了数学和学术问题外,输入整个数据集将需要巨大的内存,这通常不适合图像数据类型(在您的示例中)。例如,一张图像的大小为 224x224x3,一千张图像将有 1,505,28,000 个像素,更不用说在训练模型时需要计算梯度的中间值。一个标准数据集通常有超过 50,000 个样本。
  • 那我们为什么不想将一个单一的图像传递给网络,计算梯度并从下一个图像重新开始呢?这样每次迭代我们只缓存当前图像的数据,而不是整个数据集。迭代中的梯度记忆不是大致相同吗?为什么我们必须传递 X 个图像而不是 1 个?

标签: python deep-learning pytorch torch torchvision


【解决方案1】:

这样想;假设您想了解狗和猫之间的区别,而您以前从未见过它们。 这些批次将是我们一次向您展示 10 张狗和猫的图像。在说 4x10 图像(4 批)之后,您可以相当快地了解猫和狗的一些差异,但如果到目前为止,您看到的所有狗都是大型犬,那么您当然会有偏见,因此您可能会将所有小型犬归类为猫。经过足够多的批次后,您将学习,然后忘记不同的功能,因为您不会一次看到所有功能,但重要的是,您开始快速学习一些东西。

另一方面,假设我们向您展示了 100 张图片而不是 10 张图片。您将需要更长的时间查看所有图片并将它们相互比较,但您将了解“一次性”的不同之处可以这么说。

不管怎样;当您处理完这些图像(作为批次或整个数据集)后,我可以向您展示一张图像,您可以告诉我它是狗还是猫,即使您已经从多张图像中学习。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2020-10-29
    • 2017-11-30
    • 1970-01-01
    • 2021-06-17
    • 2016-11-05
    • 1970-01-01
    • 2015-09-28
    相关资源
    最近更新 更多