这可能只是一个 hack,但我找到了解决问题的方法。
对象检测器需要安装tf_slim 包。在tf_slim 包中,有一个名为learning.py 的模块。
完整的路径可能如下所示:/usr/local/lib/python3.6/site-packages/tf_slim/learning.py
在learning.py 中,从第 764 行开始,代码如下所示:
try:
while not sv.should_stop():
total_loss, should_stop = train_step_fn(sess, train_op, global_step,
train_step_kwargs)
if should_stop:
logging.info('Stopping Training.')
sv.request_stop()
break
except errors.OutOfRangeError as e:
# OutOfRangeError is thrown when epoch limit per
# tf.compat.v1.train.limit_epochs is reached.
logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
我写了一个小的if 语句来检查total_loss 的最后五个值的最大值,如果低于某个阈值(在本例中为3),则使should_stopTrue。如下所示:
try:
total_loss_list = []
while not sv.should_stop():
total_loss, should_stop = train_step_fn(sess, train_op, global_step,
train_step_kwargs)
total_loss_list.append(total_loss)
if len(total_loss_list) > 5:
if max(total_loss_list[-5:]) < 3:
should_stop = True
if should_stop:
logging.info('Stopping Training.')
sv.request_stop()
break
except errors.OutOfRangeError as e:
# OutOfRangeError is thrown when epoch limit per
# tf.compat.v1.train.limit_epochs is reached.
logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
如果损失值连续五步低于 3,则训练停止。这样做的缺点是,tf_slim 的包分发必须更改。并且每次处理新的对象检测问题时,这个阈值损失值都会发生变化。更好的方法是使用您提供阈值损失值的配置文件。但我暂时停在这里。
如果其他人有更好的解决方案,请分享。
我希望这可以帮助别人。谢谢!