【Question title】: How to freeze a device-specific saved model?
【Posted】: 2020-07-09 09:05:20
【Question】:

I need to freeze saved models for serving, but some of the saved models are device specific. How can I solve this?

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.tables_initializer())

    tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
    inference_graph_def=tf.get_default_graph().as_graph_def()

    for node in inference_graph_def.node:
        node.device = ''

    frozen_graph_path = os.path.join(frozen_dir, 'frozen_inference_graph.pb')
    output_keys = ['ToInt64', 'ToInt32', 'while/Exit_5']
    output_node_names = ','.join(["%s/%s" % ('NmtModel', output_key) for output_key in output_keys])
    _ = freeze_graph.freeze_graph(
            input_graph=inference_graph_def,
            input_saver=None,
            input_binary=True,
            input_saved_model_dir=saved_model_dir,
            input_checkpoint=None,
            output_node_names=output_node_names,
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=frozen_graph_path,
            clear_devices=True,
            initializer_nodes='')
    logging.info("export frozen_inference_graph.pb success!!!")

Loading the saved model fails with the following error:

Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
     [[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16)  = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]

Caused by op u'NmtModel/transpose/Rank', defined at:
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 55, in <module>
    absl_app.run(main)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 50, in main
    saved_model2frozen(FLAGS.saved_model_dir, FLAGS.frozen_dir)
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 16, in saved_model2frozen
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 197, in load
    return loader.load(sess, tags, import_scope, **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 350, in load
    **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 278, in load_graph
    meta_graph_def, import_scope=import_scope, **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1696, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 806, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 442, in import_graph_def
    _ProcessNewOps(graph)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 234, in _ProcessNewOps
    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3440, in _add_new_tf_operations
    for c_op in c_api_util.new_tf_operations(self)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3299, in _create_op_from_tf_operation
    ret = Operation(c_op, self)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
     [[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16)  = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]

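It seems that some of the models were trained on multiple GPUs and then exported to a saved model without the device information being cleared.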
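
One way to check this diagnosis is to read `saved_model.pb` directly and tally the device strings recorded on its nodes. The following is a minimal sketch, not part of the original post: `count_devices` and `read_saved_model` are hypothetical helper names, and it assumes TensorFlow is installed so that the `SavedModel` protobuf class can be imported.

```python
import os
from collections import Counter


def count_devices(saved_model):
    """Tally the device string of every node in a SavedModel proto."""
    counts = Counter()
    for mg in saved_model.meta_graphs:
        # Nodes in the main graph
        for node in mg.graph_def.node:
            counts[node.device] += 1
        # Nodes inside library functions
        for func in mg.graph_def.library.function:
            for node in func.node_def:
                counts[node.device] += 1
    return counts


def read_saved_model(saved_model_dir):
    """Parse saved_model.pb without building a graph or a session."""
    # Imported lazily; assumes TensorFlow is installed.
    from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
    sm = SavedModel()
    with open(os.path.join(saved_model_dir, 'saved_model.pb'), 'rb') as f:
        sm.ParseFromString(f.read())
    return sm
```

Calling `count_devices(read_saved_model(saved_model_dir))` returns a `Counter` keyed by device string. An entry such as `/device:GPU:4` would indicate that the placement was saved with the model; the empty-string key counts nodes with no pinned device.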

【Comments】:

    Tags: tensorflow tensorflow-serving


    【Answer 1】:

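    I'm not sure whether there is a better way to go about this, but one possibility is to simply edit the saved model information to remove the device specifications. The snippet below should do that, although you should back up the saved model before using it, just in case.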

    from pathlib import Path
    import tensorflow as tf
    from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
    
    # Read the model file
    model_path = saved_model_dir
    graph_path = Path(model_path, 'saved_model.pb')
    sm = SavedModel()
    with graph_path.open('rb') as f:
        sm.ParseFromString(f.read())
    # Go through graph and functions to remove every device specification
    for mg in sm.meta_graphs:
        for node in mg.graph_def.node:
            node.device = ''
        for func in mg.graph_def.library.function:
            for node in func.node_def:
                node.device = ''
    # Write over file
    with graph_path.open('wb') as f:
        f.write(sm.SerializeToString())
    
    # Now load model as usual
    # ...
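
The backup advice above can be folded into the same script. This is a sketch under stated assumptions rather than the answer author's code: `clear_devices` and `strip_saved_model_devices` are hypothetical names, the `.bak` suffix is an arbitrary choice, and `Path.read_bytes` / `Path.write_bytes` need Python 3.5 or later (the traceback in the question shows Python 2.7, where `pathlib` is not in the standard library).

```python
import shutil
from pathlib import Path


def clear_devices(saved_model):
    """Blank every non-empty device field; return how many nodes changed."""
    changed = 0
    for mg in saved_model.meta_graphs:
        for node in mg.graph_def.node:
            if node.device:
                node.device = ''
                changed += 1
        for func in mg.graph_def.library.function:
            for node in func.node_def:
                if node.device:
                    node.device = ''
                    changed += 1
    return changed


def strip_saved_model_devices(saved_model_dir):
    """Back up saved_model.pb, then rewrite it without device specs."""
    # Imported lazily; assumes TensorFlow is installed.
    from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

    graph_path = Path(saved_model_dir, 'saved_model.pb')
    # Hypothetical backup name: saved_model.pb.bak next to the original
    backup_path = graph_path.with_name(graph_path.name + '.bak')
    shutil.copy2(str(graph_path), str(backup_path))

    sm = SavedModel()
    sm.ParseFromString(graph_path.read_bytes())
    changed = clear_devices(sm)
    graph_path.write_bytes(sm.SerializeToString())
    return changed
```

The return value is the number of nodes whose device field was cleared, so a result of 0 suggests the file had no pinned devices and the loading error has another cause.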
    

    【Comments】:
