【问题标题】:how to properly saving loaded h5 model to pb with TF2如何使用 TF2 将加载的 h5 模型正确保存到 pb
【发布时间】:2019-06-18 10:21:02
【问题描述】:

我加载了一个已保存的 h5 模型并希望将模型保存为 pb。 模型在训练期间使用tf.keras.callbacks.ModelCheckpoint 回调函数保存。

TF 版本:2.0.0a
编辑:同样的问题也与 2.0.0-beta1

我保存 pb 的步骤:

  1. 我先设置K.set_learning_phase(0)
  2. 然后我用tf.keras.models.load_model 加载模型
  3. 然后,我定义了freeze_session() 函数。
  4. (可选我编译模型)
  5. 然后将freeze_session() 函数与tf.keras.backend.get_session 一起使用

错误我得到,无论是否编译:

AttributeError: 模块 'tensorflow.python.keras.api._v2.keras.backend' 没有属性“get_session”

我的问题:

  1. TF2 不再有get_session 了吗? (我知道tf.contrib.saved_model.save_keras_model 已经不存在了,我也尝试了tf.saved_model.save,但没有成功)

  2. 或者get_session 仅在我实际训练模型并且仅加载 h5 不起作用时才起作用 编辑:同样对于新训练的会话,没有可用的 get_session。

    • 如果是这样,我将如何将未经培训的 h5 转换为 pb?有好的教程吗?

感谢您的帮助


更新

自从 TF2.x 正式发布以来,图形/会话概念发生了变化。应该使用savedmodel api。 您可以将tf.compat.v1.disable_eager_execution() 与TF2.x 一起使用,它将生成一个pb 文件。但是,我不确定它是哪种 pb 文件类型,因为保存的模型组合从 TF1 更改为 TF2。我会继续挖掘。

【问题讨论】:

    标签: tensorflow keras protocol-buffers tensorflow2.0 keras-2


    【解决方案1】:

    我确实将模型从 h5 模型保存到 pb

    import logging
    import tensorflow as tf
    from tensorflow.compat.v1 import graph_util
    from tensorflow.python.keras import backend as K
    from tensorflow import keras
    
    # necessary !!!
    tf.compat.v1.disable_eager_execution()
    
    h5_path = '/path/to/model.h5'
    model = keras.models.load_model(h5_path)
    model.summary()
    # save pb
    with K.get_session() as sess:
        output_names = [out.op.name for out in model.outputs]
        input_graph_def = sess.graph.as_graph_def()
        for node in input_graph_def.node:
            node.device = ""
        graph = graph_util.remove_training_nodes(input_graph_def)
        graph_frozen = graph_util.convert_variables_to_constants(sess, graph, output_names)
        tf.io.write_graph(graph_frozen, '/path/to/pb/model.pb', as_text=False)
    logging.info("save pb successfully!")
    

    我使用 TF2 转换模型如下:

    1. 在训练时将keras.callbacks.ModelCheckpoint(save_weights_only=True) 传递给model.fit 并保存checkpoint
    2. 训练后,self.model.load_weights(self.checkpoint_path)加载checkpoint
    3. self.model.save(h5_path, overwrite=True, include_optimizer=False)另存为h5;
    4. h5 转换为pb 就像上面一样;

    【讨论】:

    • 嗨,write_graph 的路径和名称是分开的。它应该是 'tf.io.write_graph(graph_frozen, '/path/to/pb, model.pb', as_text=False)'
    【解决方案2】:

    我想知道同样的事情,因为我正在尝试使用 get_session() 和 set_session() 来释放 GPU 内存。这些功能似乎缺失了aren't in the TF2.0 Keras documentation。我想这与 Tensorflow 切换到渴望执行有关,因为不再需要直接会话访问。

    【讨论】:

    • 嗨,詹姆斯,我也这么认为。我想知道他们是否会将其中的一部分带回不同的图书馆或建立新的方式。
    • 从 tensorflow.compat.v1.keras.backend 导入 get_session
    【解决方案3】:

    使用

    from tensorflow.compat.v1.keras.backend import get_session
    

    在 keras 2 和张量流 2.2 中

    然后调用

    import logging
    import tensorflow as tf
    from tensorflow.compat.v1 import graph_util
    from tensorflow.python.keras import backend as K
    from tensorflow import keras
    from tensorflow.compat.v1.keras.backend import get_session
    
    # necessary !!!
    tf.compat.v1.disable_eager_execution()
    
    h5_path = '/path/to/model.h5'
    model = keras.models.load_model(h5_path)
    model.summary()
    # save pb
    with get_session() as sess:
        output_names = [out.op.name for out in model.outputs]
        input_graph_def = sess.graph.as_graph_def()
        for node in input_graph_def.node:
            node.device = ""
        graph = graph_util.remove_training_nodes(input_graph_def)
        graph_frozen = graph_util.convert_variables_to_constants(sess, graph, output_names)
        tf.io.write_graph(graph_frozen, '/path/to/pb/model.pb', as_text=False)
    logging.info("save pb successfully!")
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-09-29
      • 1970-01-01
      • 2018-01-09
      • 2018-10-25
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-09-17
      相关资源
      最近更新 更多