【问题标题】:output prediction of pytorch lightning modelpytorch闪电模型的输出预测
【发布时间】:2021-04-24 16:52:53
【问题描述】:

这可能是一个非常简单的问题。我刚开始使用 PyTorch 闪电,不知道如何在训练后接收模型的输出。

我对 y_train 和 y_test 作为某种数组(稍后步骤中的 PyTorch 张量或 NumPy 数组)的预测感兴趣,以便使用不同的脚本在标签旁边绘制。

dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset, **train_params)
val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0,max_epochs=max_epochs, logger=logger, progress_bar_refresh_rate=20, callbacks=[early_stop_callback], num_sanity_val_steps=0)
trainer.fit(mynet)

在我的闪电模块中,我有以下功能:

def __init__(self, random_inputs):

def forward(self, x):

def train_dataloader(self):
    
def val_dataloader(self):

def training_step(self, batch, batch_nb):

def training_epoch_end(self, outputs):

def validation_step(self, batch, batch_nb):

def validation_epoch_end(self, outputs):

def configure_optimizers(self):

我需要一个特定的预测函数还是有什么我看不到的已经实现的方法?

【问题讨论】:

    标签: python pytorch pytorch-lightning


    【解决方案1】:

    我不同意这些答案:OP 的问题似乎集中在他应该如何使用经过闪电训练的模型来获得一般情况下的预测,而不是针对训练管道中的特定步骤。在这种情况下,用户不需要靠近 Trainer 对象 - 这些不打算用于一般预测,因此上面的答案鼓励反模式(每次我们都随身携带一个 trainer 对象)想要做一些预测)给任何将来阅读这些答案的人。

    不使用trainer,我们可以直接从已定义的闪电模块获得预测:如果我有闪电模块model = Net(...) 的(训练过的)实例,然后使用该模型获得输入@987654323 的预测@ 只需调用 model(x) 即可实现(只要 forward 方法已在 Lightning 模块上实现/覆盖 - 这是必需的)。

    相比之下,Trainer.predict() 通常不是使用经过训练的模型获得预测的预期方法。 Trainer API 为您的 LightningModule 提供了 tunefittest 的方法,作为训练管道的一部分,在我看来,predict 方法是为单独数据加载器上的临时预测提供的,作为更少的“标准”训练步骤。

    OP 的问题(我需要一个特定的预测函数还是有任何我看不到的已经实现的方法?)暗示他们不熟悉forward() 方法在 PyTorch 中的工作方式,但询问是否已经有一种他们看不到的预测方法。因此,完整的答案需要进一步解释forward() 方法适合预测过程的位置:

    model(x) 起作用的原因是因为闪电模块是torch.nn.Module 的子类,它们实现了一个名为__call__() 的魔术方法,这意味着我们可以像调用函数一样调用类实例。 __call__() 又调用 forward(),这就是为什么我们需要在 Lightning 模块中覆盖该方法。

    注意。因为forward只是我们使用model(x)时调用的逻辑的一部分,所以除非您有特定的偏离原因,否则始终建议使用model(x)而不是model.forward(x)进行预测。

    【讨论】:

    • 很高兴您指出了如何直接运行网络,因为从 Pytorch 开始时,Lightning 从未使用过 Pytorch,直接隐藏了底层机制。我认为在某些情况下使用 Trainer 类进行预测仍然是合理的,因为它处理将模型和数据放到 GPU 上,它可以调用某些钩子,为什么要重新发明轮子?这不是反模式,将类重命名为 Commander 并且您的大部分论点都是无效的。我仍然认为你指出它很好,但是反模式太强大了。
    【解决方案2】:

    您也可以使用predict 方法。这是文档中的示例。 https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html

    class LitMNISTDreamer(LightningModule):
    
        def forward(self, z):
            imgs = self.decoder(z)
            return imgs
    
        def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None):
            return self(batch)
    
    
    model = LitMNISTDreamer()
    trainer.predict(model, datamodule) 
    

    【讨论】:

    • predict 方法好像是同时添加的。我只是很困惑它以前不可用。
    • 是的,他们似乎非常擅长添加新内容
    • 使用trainer.predict()和使用model()有什么区别?第一个选项是否自动将调用包装在 eval 模式和 no_grad 中?
    • 训练器将你的模型和输入放到显卡上,限制batch的数量(如果设置了,见训练器__init__args),执行分布式计算等等。
    【解决方案3】:

    训练器有一个test 函数。您可能想查看 pytorch-lightning 的原始文档以了解更多详细信息:https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#testing

    【讨论】:

    • 太棒了。我自己怎么没有找到这个。很可能是因为我得到的所有错误。但一切都解决了。
    • 我不相信 .test 允许您返回张量(其目的主要是通过 logging API 收集日志 - 目前不接受列表或 torch/np.arrays )。所以 .predict() 似乎是前进的方向。
    【解决方案4】:

    您可以通过两种方式尝试预测:

    1. 照常执行批量预测。
    test_dataset = Dataset(test_tensor)
    test_generator = torch.utils.data.DataLoader(test_dataset, **test_params)
    
    mynet.eval()
    batch = next(iter(test_generator))
    with torch.no_grad():
        predictions_single_batch = mynet(**unpacked_batch)
    
    1. 实例化一个新的Trainer 对象。 Trainer 的predict API 允许你传递任意的DataLoader
    test_dataset = Dataset(test_tensor)
    test_generator = torch.utils.data.DataLoader(test_dataset, **test_params)
    
    predictor = pl.Trainer(gpus=1)
    predictions_all_batches = predictor.predict(mynet, dataloaders=test_generator)
    

     我注意到,在第二种情况下,PytorchLightning 会处理诸如将张量和模型移动到(而不是离开)GPU 之类的事情,这与其执行潜力保持一致分布式预测。它也不返回任何梯度附加的损失值,这有助于消除编写样板代码(如with torch.no_grad())的需要。

    【讨论】:

      猜你喜欢
      • 2023-03-10
      • 1970-01-01
      • 2022-08-07
      • 1970-01-01
      • 2021-04-06
      • 2021-11-08
      • 2022-12-06
      • 2020-03-08
      • 2021-12-18
      相关资源
      最近更新 更多