【问题标题】:early stopping in tensorflow object detection api(TensorFlow 对象检测 API 中的提前停止)
【发布时间】:2020-11-24 04:49:35
【问题描述】:

我正在尝试在 TF OD API 中实现提前停止。我用过这个code

这是我的 EarlyStoppingHook(它本质上只是上面代码的副本):

class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a stop once a monitored tensor stops improving.

    The monitored graph element is fetched on every `session.run` call. When
    its value has failed to improve on the best value seen so far for
    `patience` consecutive calls, `run_context.request_stop()` is invoked.
    """

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """Configure the early-stopping criterion.

        Args:
            monitor: Name of the graph element (or the graph element itself)
                to watch. It is resolved against the default graph in
                `begin()`.
            min_delta: Minimum change in the monitored value that counts as
                an improvement.
            patience: Number of consecutive `after_run` calls without
                improvement after which a stop is requested.
            mode: One of 'auto', 'min' or 'max'. In 'min' mode smaller values
                are better, in 'max' mode larger values are better. In
                'auto' mode the direction is 'max' when the monitor name
                contains 'acc' and 'min' otherwise. Unknown values fall back
                to 'auto' with a logged warning.
        """
        self.monitor = monitor
        # Resolved fetch target; populated by begin() for the current graph.
        self._monitor_tensor = None
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0      # consecutive calls without improvement
        self.max_wait = 0  # longest non-improving streak that later recovered
        self.ind = 0       # number of after_run calls seen so far
        if mode not in ['auto', 'min', 'max']:
            # Only `mode` is a format argument; passing anything extra makes
            # the logging call fail to format the message.
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # `monitor` may be a graph element rather than a name, so look at
            # its `name` attribute instead of assuming a string.
            monitor_name = (monitor if isinstance(monitor, str)
                            else getattr(monitor, 'name', ''))
            if 'acc' in monitor_name:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # after_run compares `current - min_delta` against `best`; flipping
        # the sign for the "smaller is better" case turns that into
        # `current + |min_delta| < best`.
        if self.monitor_op is np.less:
            self.min_delta *= -1

        # `np.inf` is the canonical spelling; the `np.Inf` alias was removed
        # in NumPy 2.0.
        self.best = np.inf if self.monitor_op is np.less else -np.inf

    def begin(self):
        """Resolve `self.monitor` to a tensor in the current default graph.

        The result is stored separately from `self.monitor` so that the hook
        can be reused when `begin()` is called again for a freshly built
        graph; re-resolving an already converted tensor from an older graph
        would be rejected by `as_graph_element`.
        """
        graph = tf.get_default_graph()
        element = graph.as_graph_element(self.monitor)
        if isinstance(element, tf.Operation):
            element = element.outputs[0]
        self._monitor_tensor = element

    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Ask the session to also fetch the monitored value on this run."""
        fetch = self._monitor_tensor
        if fetch is None:
            # begin() has not run yet; fall back to whatever was configured.
            fetch = self.monitor
        return session_run_hook.SessionRunArgs(fetch)

    def after_run(self, run_context, run_values):
        """Update the improvement bookkeeping and request a stop if needed.

        Args:
            run_context: Context object used to request that training stops.
            run_values: Holds the fetched monitored value in `.results`.
        """
        self.ind += 1

        current = run_values.results

        # Periodic progress trace, printed every 200 calls.
        if self.ind % 200 == 0:
            print(f"loss value (inside hook!!! ): {current}, best: {self.best}, wait: {self.wait}, max_wait: {self.max_wait}")

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            # Remember the longest streak we recovered from before resetting.
            if self.max_wait < self.wait:
                self.max_wait = self.wait
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()

我使用这样的类:


# Request a stop once the graph element named 'total_loss' has gone 2000
# consecutive hook invocations without improving.
early_stopping_hook = EarlyStoppingHook(monitor='total_loss', patience=2000)

# Attach the hook to the training spec.
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=train_steps,
    hooks=[early_stopping_hook],
)

我不明白什么是total_loss?这是 val 损失还是 train 损失?另外我不明白这些损失('total_loss'、'loss_1'、'loss_2')是在哪里定义的。

【问题讨论】:

    标签: python tensorflow tensorflow-estimator


    【解决方案1】:

    所以,这对我有用

    from matplotlib import pyplot as plt
    import numpy as np
    
    import collections
    import os 
    
    _EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'
    
    def _summaries(eval_dir):
      """Iterate over every `tensorflow.Event` stored in an eval directory.

      Args:
        eval_dir: Directory containing summary files with eval metrics.

      Yields:
        Each `tensorflow.Event` proto read from the event files matching
        `_EVENT_FILE_GLOB_PATTERN`. Nothing is yielded when `eval_dir` does
        not exist.
      """
      gfile = tf.compat.v1.gfile
      if not gfile.Exists(eval_dir):
        return
      pattern = os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)
      for path in gfile.Glob(pattern):
        yield from tf.compat.v1.train.summary_iterator(path)
    
    def read_eval_metrics(eval_dir):
      """Collect the scalar metrics recorded in a directory of event files.

      Args:
        eval_dir: Directory containing summary files with eval metrics.

      Returns:
        A `collections.OrderedDict`, sorted by step, that maps each event
        step to a `dict` of summary tag -> `simple_value`. Events without a
        summary, and summary values without a `simple_value`, are ignored.
      """
      per_step = collections.defaultdict(dict)
      for event in _summaries(eval_dir):
        if event.HasField('summary'):
          scalars = {
              entry.tag: entry.simple_value
              for entry in event.summary.value
              if entry.HasField('simple_value')
          }
          # Only touch the step entry when there is something to record.
          if scalars:
            per_step[event.step].update(scalars)

      ordered = collections.OrderedDict()
      for step in sorted(per_step):
        ordered[step] = per_step[step]
      return ordered
      
    # Load the per-step metrics, then split them into steps (x) and the
    # total loss recorded at each of those steps (y).
    met_dict_2 = read_eval_metrics(
        '/content/gdrive2/My Drive/models/retinanet/eval_0')
    x = list(met_dict_2)
    y = [metrics['Loss/total_loss'] for metrics in met_dict_2.values()]
    

    read_eval_metrics 函数返回一个字典,其中键是迭代次数(global step),值是在该评估步骤中计算出的各项指标和损失。您也可以将此函数用于训练(train)事件文件,只需更改路径即可。

    返回字典中的一个键值对示例。

    (4988, {'DetectionBoxes_Precision/Precision@.50IOU': 0.12053315341472626,
                   'DetectionBoxes_Precision/mAP': 0.060865387320518494,
                   'DetectionBoxes_Precision/mAP (large)': 0.07213596999645233,
                   'DetectionBoxes_Precision/mAP (medium)': 0.062120337039232254,
                   'DetectionBoxes_Precision/mAP (small)': 0.02642354555428028,
                   'DetectionBoxes_Precision/mAP@.50IOU': 0.11469704657793045,
                   'DetectionBoxes_Precision/mAP@.75IOU': 0.06001879647374153,
                   'DetectionBoxes_Recall/AR@1': 0.13470394909381866,
                   'DetectionBoxes_Recall/AR@10': 0.20102562010288239,
                   'DetectionBoxes_Recall/AR@100': 0.2040158212184906,
                   'DetectionBoxes_Recall/AR@100 (large)': 0.2639017701148987,
                   'DetectionBoxes_Recall/AR@100 (medium)': 0.20173722505569458,
                   'DetectionBoxes_Recall/AR@100 (small)': 0.10018187761306763,
                   'Loss/classification_loss': 1.0127471685409546,
                   'Loss/localization_loss': 0.3542810380458832,
                   'Loss/regularization_loss': 0.708609938621521,
                   'Loss/total_loss': 2.0756208896636963,
                   'learning_rate': 0.0006235376931726933,
                   'loss': 2.0756208896636963})
    

    所以我最终在 EarlyStoppingHook 中将 monitor 参数设置为 'Loss/total_loss' 而不是 'total_loss'。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-05-03
      • 2017-12-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-01-17
      • 2021-02-01
      相关资源
      最近更新 更多