【问题标题】:How can I convert a trained Tensorflow model to Keras?如何将训练有素的 Tensorflow 模型转换为 Keras?
【发布时间】:2017-11-11 23:02:24
【问题描述】:

我有一个训练有素的 Tensorflow 模型和权重向量,它们已分别导出到 protobuf 和权重文件。

如何将这些转换为 Keras 可以使用的 JSON 或 YAML 和 HDF5 文件?

我有 Tensorflow 模型的代码,因此将 tf.Session 转换为 keras 模型并将其保存在代码中也是可以接受的。

【问题讨论】:

    标签: tensorflow keras


    【解决方案1】:

    我认为keras中的回调也是一种解决方案。

    ckpt文件可以通过TF保存:

    saver = tf.train.Saver()
    saver.save(sess, checkpoint_name)
    

    要在 Keras 中加载检查点,您需要一个回调类,如下所示:

    class RestoreCkptCallback(keras.callbacks.Callback):
        def __init__(self, pretrained_file):
            self.pretrained_file = pretrained_file
            self.sess = keras.backend.get_session()
            self.saver = tf.train.Saver()
        def on_train_begin(self, logs=None):
            if self.pretrian_model_path:
                self.saver.restore(self.sess, self.pretrian_model_path)
                print('load weights: OK.')
    

    然后在你的 keras 脚本中:

     model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
     restore_ckpt_callback = RestoreCkptCallback(pretrian_model_path='./XXXX.ckpt') 
     model.fit(x_train, y_train, batch_size=128, epochs=20, callbacks=[restore_ckpt_callback])
    

    这样就好了。 我认为它很容易实现,希望对您有所帮助。

    【讨论】:

    • 嗨,MyCallbacks 是什么?
    • 嗨@Austin,MyCallbacks 是RestoreCkptCallback。我已经更正了我的帖子。谢谢你的提醒!
    • 但是,这需要你在 keras 中编写模型,不是吗?
    【解决方案2】:

    keras 的创建者 Francois Chollet 在 04/2017 中表示:“您不能将任意 TensorFlow 检查点转换为 Keras 模型。但是,您可以做的是构建一个等效的 Keras 模型,然后将权重加载到这个 Keras 模型中" ,见https://github.com/keras-team/keras/issues/5273。据我所知,这并没有改变。

    一个小例子:

    首先,您可以像这样提取张量流检查点的权重

    PATH_REL_META = r'checkpoint1.meta'
        
    # start tensorflow session
    with tf.Session() as sess:
        
        # import graph
        saver = tf.train.import_meta_graph(PATH_REL_META)
        
        # load weights for graph
        saver.restore(sess, PATH_REL_META[:-5])
            
        # get all global variables (including model variables)
        vars_global = tf.global_variables()
        
        # get their name and value and put them into dictionary
        sess.as_default()
        model_vars = {}
        for var in vars_global:
            try:
                model_vars[var.name] = var.eval()
            except:
                print("For var={}, an exception occurred".format(var.name))
    

    它也可能用于导出 tensorflow 模型以用于 tensorboard,请参阅https://stackoverflow.com/a/43569991/2135504

    其次,您像往常一样构建您的 keras 模型并通过“model.compile”完成它。请注意,您需要按名称定义每个层,然后将其添加到模型中,例如

    layer_1 = keras.layers.Conv2D(6, (7,7), activation='relu', input_shape=(48,48,1))
    net.add(layer_1)
    ...
    net.compile(...)
    

    第三,您可以使用 tensorflow 值设置权重,例如

    layer_1.set_weights([model_vars['conv7x7x1_1/kernel:0'], model_vars['conv7x7x1_1/bias:0']])
    

    【讨论】:

    • 如何处理 batch_norm 层,因为它们有 4 个参数并且似乎会导致问题...
    • @ADA:不是 100% 肯定,但如果你用最小的代码示例提出新问题,我或其他人可以看看。
    • 感谢我发了一个帖子。我很想知道我缺少什么
    【解决方案3】:

    目前,Tensorflow 或 Keras 没有直接内置支持将冻结模型或检查点文件转换为 hdf5 格式。

    但既然你提到你有 Tensorflow 模型的代码,你将不得不在 Keras 中重写该模型的代码。然后,您必须从检查点文件中读取变量的值,并使用layer.load_weights(weights) 方法将其分配给 Keras 模型。

    除了这种方法之外,我建议您直接在 Keras 中进行培训,因为它声称 Keras' optimizers are 5-10% times faster than Tensorflow's optimizers。另一种方法是使用 tf.contrib.keras 模块在 Tensorflow 中编写代码,然后直接将文件保存为 hdf5 格式。

    【讨论】:

      【解决方案4】:

      不确定这是否是您正在寻找的,但我碰巧对 TF 1.2 中新发布的 keras 支持做了同样的事情。您可以在此处找到有关 API 的更多信息:https://www.tensorflow.org/api_docs/python/tf/contrib/keras

      为了节省您一点时间,我还发现我必须包含如下所示的 keras 模块,并将附加的 python.keras 附加到 API 文档中显示的内容中。

      从 tensorflow.contrib.keras.python.keras.models 导入顺序

      希望能帮助您到达您想去的地方。基本上,一旦集成,您就可以像往常一样处理模型/权重导出。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2017-11-17
        • 2021-01-01
        • 2019-01-20
        • 1970-01-01
        • 2020-05-21
        • 2016-11-01
        • 2020-09-01
        相关资源
        最近更新 更多