【发布时间】:2020-10-29 04:22:52
【问题描述】:
我想在LightningModule 的validation_epoch_end 方法中创建一个新张量。来自官方docs(第48页)指出我们应该避免直接调用.cuda()或.to(device):
没有 .cuda() 或 .to() 调用。 . . Lightning 会为您完成这些工作。
我们鼓励使用type_as 方法转移到正确的设备。
new_x = new_x.type_as(x.type())
但是,在validation_epoch_end 步骤中,我没有任何张量可以从(通过type_as 方法)以干净的方式复制设备。
我的问题是,如果我想在这个方法中创建一个新的张量并将其转移到模型所在的设备上,我该怎么办?
我唯一能想到的就是在outputs 字典中找到一个张量,但感觉有点乱:
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
output = self(self.__test_input.type_as(avg_loss))
有什么干净的方法可以实现吗?
【问题讨论】:
标签: python pytorch pytorch-lightning