我一直在查看源代码,但没有发现与此相关的FLAG。
但是,在https://github.com/tensorflow/models/blob/master/research/object_detection/model_main.py 的文件model_main.py
可以找到如下主函数定义:
def main(unused_argv):
flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('pipeline_config_path')
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
run_config=config,
...
我们的想法是以类似的方式修改它,例如以下方式:
config_proto = tf.ConfigProto()
config_proto.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir, session_config=config_proto)
因此,添加config_proto 并更改config,但保持所有其他内容相同。
另外,allow_growth 使程序可以根据需要使用尽可能多的 GPU 内存。因此,根据您的 GPU,您最终可能会吃掉所有内存。在这种情况下,您可能需要使用
config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
它定义了要使用的内存比例。
希望这有帮助。
如果您不想修改文件,似乎应该打开一个问题,因为我没有看到任何 FLAG。除非 FLAG
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
'file.')
表示与此相关的东西。但我不这么认为,因为在model_lib.py 中似乎与训练、评估和推断配置有关,而不是 GPU 使用配置。