【Question title】: How to extract loss and accuracy from the logger for each epoch in PyTorch Lightning?
【Posted】: 2021-11-15 11:48:29
【Question】:

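I want to extract all the data and make the plots myself instead of using TensorBoard. My understanding is that, since TensorBoard can draw line charts, all the loss and accuracy logs must be stored in the defined directory.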

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

However, I would like to know how to extract all the logs from the logger in PyTorch Lightning. Below is a code example of the training part.

#model
ssl_classifier = SSLImageClassifier(lr=lr)

#train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')

trainer = pl.Trainer(progress_bar_refresh_rate=20,
                            gpus=1,
                            max_epochs = max_epoch,
                            logger = logger,
                            )

trainer.fit(ssl_classifier, train_loader, val_loader)

I have confirmed that trainer.logger.log_dir returns the directory where the logs appear to be saved, and that trainer.logger.log_metrics returns <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>

trainer.logged_metrics returns only the logs of the last epoch, for example:

{'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185}

Do you know how to solve this?

【Comments】:

    Tags: logging pytorch tensorboard pytorch-lightning


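    【Solution 1】:

    Lightning does not store all logs by itself. All it does is stream them into the logger instance, and the logger then decides what to do with them.

    The best way to retrieve all logged metrics is with a custom callback: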

    from pytorch_lightning.callbacks import Callback

    class MetricTracker(Callback):

      def __init__(self):
        self.collection = []

      def on_validation_batch_end(self, trainer, module, outputs, *args, **kwargs):
        # assumes validation_step returns a dict containing 'val_acc'
        vacc = outputs['val_acc'] # you can access them here
        self.collection.append(vacc) # track them

      def on_validation_epoch_end(self, trainer, module):
        elogs = trainer.logged_metrics # access it here
        self.collection.append(elogs)
        # do whatever is needed
    

    You can then access everything that was logged from the callback instance:

    cb = MetricTracker()
    Trainer(callbacks=[cb])
    
    cb.collection # after training, do your plotting and stuff
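To plot from the callback, the collected entries first have to be regrouped by metric name. The following is a minimal sketch, not part of the original answer: the helper `to_series` is a hypothetical name, the sample values are illustrative only, and it assumes each per-epoch entry is a dict like the one `trainer.logged_metrics` returns. Non-dict entries (such as the per-batch values the callback also appends) are skipped.

```python
from collections import defaultdict


def to_series(collection):
    """Regroup a list of per-epoch metric dicts into {metric_name: [values...]}."""
    series = defaultdict(list)
    for entry in collection:
        if not isinstance(entry, dict):
            # per-batch scalars stored in the same list are ignored here
            continue
        for name, value in entry.items():
            # float() accepts Python numbers and one-element tensors alike
            series[name].append(float(value))
    return dict(series)


# Hypothetical stand-in for cb.collection; the numbers are illustrative only.
collection = [
    {"epoch": 0, "val_acc": 0.5, "val_loss": 2.0},
    {"epoch": 1, "val_acc": 0.75, "val_loss": 1.5},
]
series = to_series(collection)
print(series["val_loss"])
```

Each list in `series` can then be handed to any plotting library, with `series["epoch"]` as the x-axis.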
    

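    【Comments】:

      【Solution 2】:

      The accepted answer is not fundamentally wrong, but it does not follow the official (current) PyTorch Lightning guidelines.

      The recommendation is here: https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#make-a-custom-logger

      It suggests writing a class like this: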

      from pytorch_lightning.utilities import rank_zero_only
      from pytorch_lightning.loggers import LightningLoggerBase
      from pytorch_lightning.loggers.base import rank_zero_experiment
      
      
      class MyLogger(LightningLoggerBase):
          @property
          def name(self):
              return "MyLogger"
      
          @property
          @rank_zero_experiment
          def experiment(self):
              # Return the experiment object associated with this logger.
              pass
      
          @property
          def version(self):
              # Return the experiment version, int or str.
              return "0.1"
      
          @rank_zero_only
          def log_hyperparams(self, params):
              # params is an argparse.Namespace
              # your code to record hyperparameters goes here
              pass
      
          @rank_zero_only
          def log_metrics(self, metrics, step):
              # metrics is a dictionary of metric names and values
              # your code to record metrics goes here
              pass
      
          @rank_zero_only
          def save(self):
              # Optional. Any code necessary to save logger data goes here
              # If you implement this, remember to call `super().save()`
              # at the start of the method (important for aggregation of metrics)
              super().save()
      
          @rank_zero_only
          def finalize(self, status):
              # Optional. Any code that needs to be run after training
              # finishes goes here
              pass
      

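      Looking inside the LightningLoggerBase class shows which methods can be overridden.

      Here is a minimal logger of mine. It is not highly optimized, but it is a good first shot. I will edit this answer if I improve it.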

      import collections
      
      from pytorch_lightning.loggers import LightningLoggerBase
      from pytorch_lightning.loggers.base import rank_zero_experiment
      from pytorch_lightning.utilities import rank_zero_only
      
      class History_dict(LightningLoggerBase):
          def __init__(self):
              super().__init__()
      
              self.history = collections.defaultdict(list) # copy not necessary here  
              # The defaultdict in contrast will simply create any items that you try to access
      
          @property
          def name(self):
              return "Logger_custom_plot"
      
          @property
          def version(self):
              return "1.0"
      
          @property
          @rank_zero_experiment
          def experiment(self):
              # Return the experiment object associated with this logger.
              pass
      
          @rank_zero_only
          def log_metrics(self, metrics, step):
              # metrics is a dictionary of metric names and values
              for metric_name, metric_value in metrics.items():
                  if metric_name != 'epoch':
                      self.history[metric_name].append(metric_value)
                  else: # epoch case: avoid storing the same epoch several times, which happens when several losses are logged
                      if (not len(self.history['epoch']) or    # no epoch stored yet
                          not self.history['epoch'][-1] == metric_value): # the last stored epoch is not the one being added
                          self.history['epoch'].append(metric_value)
      
          def log_hyperparams(self, params):
              pass
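The epoch handling in `log_metrics` is easier to follow in isolation. Below is a minimal sketch of the same rule outside Lightning, assuming training and validation metrics arrive in separate calls within one epoch. The function `record` and the sample values are hypothetical and only illustrate the logic.

```python
from collections import defaultdict


def record(history, metrics):
    """Append one logged metrics dict to history, storing each epoch number once."""
    for name, value in metrics.items():
        if name != "epoch":
            history[name].append(value)
        elif not history["epoch"] or history["epoch"][-1] != value:
            # only store the epoch if it differs from the last stored one
            history["epoch"].append(value)


history = defaultdict(list)
# Hypothetical calls: two per epoch, one for training and one for validation.
record(history, {"train_loss": 0.9, "epoch": 0})
record(history, {"val_loss": 1.1, "epoch": 0})
record(history, {"train_loss": 0.7, "epoch": 1})
record(history, {"val_loss": 1.0, "epoch": 1})
print(dict(history))
```

With the real logger, the same structure would be available as the `history` attribute of the `History_dict` instance once training has finished, assuming the instance was passed to the `Trainer` through its `logger` argument as `TensorBoardLogger` is in the question.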
      

      【Comments】:
