【发布时间】:2020-11-24 04:49:35
【问题描述】:
我正在尝试在 TF OD API(TensorFlow Object Detection API)中实现提前停止(early stopping)。我使用了这段代码。
这是我的 EarlyStoppingHook(它本质上只是上面代码的副本):
class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a stop once a monitored tensor stops improving.

    The value of the monitored tensor is fetched on every run call. Each
    call whose value does not beat the best value seen so far (by more
    than ``min_delta``) increments a counter; when that counter reaches
    ``patience`` the hook calls ``run_context.request_stop()``.
    """

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """Configure the hook.

        Args:
            monitor: Name of the graph element to watch. It is resolved to
                a tensor in ``begin()``.
            min_delta: Margin by which a new value must beat the best value
                to count as an improvement.
            patience: Number of consecutive non-improving run calls after
                which a stop is requested.
            mode: ``'min'`` (lower is better), ``'max'`` (higher is better)
                or ``'auto'``. In ``'auto'`` mode higher is better only
                when ``monitor`` contains ``'acc'``; otherwise lower is
                better. Unknown values fall back to ``'auto'``.
        """
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0      # consecutive run calls without improvement
        self.max_wait = 0  # longest non-improving streak that ended in an improvement
        self.ind = 0       # number of after_run() calls so far

        if mode not in ['auto', 'min', 'max']:
            # Only `mode` is passed as a format argument: the message has a
            # single %s placeholder, and any extra argument makes the
            # logging call fail to format the record.
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        elif 'acc' in self.monitor:
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less

        # after_run() compares `current - min_delta` against `best`, so the
        # margin must be negated when a smaller value is the better one.
        if self.monitor_op is not np.greater:
            self.min_delta *= -1

        # Start from the worst possible value so the first observation
        # always counts as an improvement. `np.inf` is used because the
        # `np.Inf` alias no longer exists in NumPy 2.x.
        self.best = np.inf if self.monitor_op is np.less else -np.inf

    def begin(self):
        """Resolve ``self.monitor`` to a tensor in the default graph."""
        # Convert names to tensors if given
        # NOTE(review): self.monitor is overwritten with the resolved
        # tensor, so this appears to assume one graph per hook instance --
        # confirm before reusing a hook across separate train() calls.
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Ask the session to also fetch the monitored tensor."""
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        """Update the best value and request a stop when patience runs out.

        Args:
            run_context: Context used to call ``request_stop()``.
            run_values: Its ``results`` attribute is read as the current
                value of the monitored tensor.
        """
        self.ind += 1
        current = run_values.results

        # Periodic progress report, printed every 200 calls.
        if self.ind % 200 == 0:
            print(f"loss value (inside hook!!! ): {current}, best: {self.best}, wait: {self.wait}, max_wait: {self.max_wait}")

        if self.monitor_op(current - self.min_delta, self.best):
            # Improvement: record it and reset the streak, remembering the
            # longest streak seen so far.
            self.best = current
            if self.max_wait < self.wait:
                self.max_wait = self.wait
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()
我这样使用这个类:
# Watch the graph element named 'total_loss'; the hook requests a stop
# after 2000 consecutive run calls without an improvement.
early_stopping_hook = EarlyStoppingHook(patience=2000, monitor='total_loss')

# Attach the hook to the training spec so it runs alongside training.
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=train_steps,
    hooks=[early_stopping_hook],
)
我不明白 total_loss 指的是什么:它是验证(val)损失还是训练(train)损失?另外,我也不清楚这些损失('total_loss'、'loss_1'、'loss_2')是在哪里定义的。
【问题讨论】:
标签: python tensorflow tensorflow-estimator