【问题标题】:Why there's no .train() method in this Pytorch official tutorial?为什么这个 Pytorch 官方教程中没有 .train() 方法?
【发布时间】:2019-11-16 23:07:52
【问题描述】:

所以,我刚刚学习了 Pytorch,他们说你必须通过 .train() 方法将 NN 置于训练模式,然后在推断 .eval() 模式时。我正在阅读本教程,根本没有 .train() 。这是为什么呢?

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

【问题讨论】:

标签: pytorch


【解决方案1】:

.train() 将模块的self.training 属性设置为True.eval() 将其设置为 False

source for nn.Module 中可以看出,此属性最初设置为True。因此,除非您在开始训练之前致电eval(),否则您不一定(根据当前实施)需要致电train()但是重要的是,模块在训练时应该处于self.training=True 状态,所以无论如何这样做可能是一个好习惯。

此外,目前,只有一些模块(例如 dropout 和 batchnorm)会根据 self.training 属性更改其行为。因此,如果您不使用这些特定模块,则不必一定调用 .train().eval(),但同样,无论如何,这样做可能是一种很好的做法,可以让您的代码面向未来.

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-01-10
    • 2012-10-07
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多