【发布时间】:2019-11-16 23:07:52
【问题描述】:
所以,我刚刚学习了 Pytorch,他们说你必须通过 .train() 方法将 NN 置于训练模式,然后在推断 .eval() 模式时。我正在阅读本教程,根本没有 .train() 。这是为什么呢?
【问题讨论】:
-
请不要同时在 SO 和 Discuss PyTorch 上发布相同的问题。至少,等待一天在其中一个中得到答案。
标签: pytorch
所以,我刚刚学习了 Pytorch,他们说你必须通过 .train() 方法将 NN 置于训练模式,然后在推断 .eval() 模式时。我正在阅读本教程,根本没有 .train() 。这是为什么呢?
【问题讨论】:
标签: pytorch
.train() 将模块的self.training 属性设置为True。 .eval() 将其设置为 False。
从source for nn.Module 中可以看出,此属性最初设置为True。因此,除非您在开始训练之前致电eval(),否则您不一定(根据当前实施)需要致电train()。 但是重要的是,模块在训练时应该处于self.training=True 状态,所以无论如何这样做可能是一个好习惯。
此外,目前,只有一些模块(例如 dropout 和 batchnorm)会根据 self.training 属性更改其行为。因此,如果您不使用这些特定模块,则不必一定调用 .train() 和 .eval(),但同样,无论如何,这样做可能是一种很好的做法,可以让您的代码面向未来.
【讨论】: