【问题标题】:TFF loading a pre-trained Keras modelTFF 加载预训练的 Keras 模型
【发布时间】:2020-05-13 22:57:06
【问题描述】:

我的目标是从 .hdf5 文件(它是 Keras 模型)加载基础模型,并继续使用联邦学习对其进行训练。以下是我为 FL 初始化基本模型的方法:

def model_fn():
    model = tf.keras.load_model(path/to/model.hdf5)
    return tff.learning.from_keras_model(model=model, 
                                         dummy_batch=db, 
                                         loss=loss, 
                                         metrics=metrics)

trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()

但是,生成的 state.model 权重似乎是随机初始化的,并且与我保存的模型不同。当我在任何联合训练之前评估模型的性能时,它作为一个随机初始化的模型执行:50% 的准确度。以下是我评估性能的方式:

def evaluate(state):
    keras_model = tf.keras.models.load_model(path/to/model.hdf5, compile=False)
    tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    keras_model.compile(loss=loss, metrics=metrics)
    return keras_model.evaluate(features, values)

如何使用保存的模型权重初始化 tff 模型?

【问题讨论】:

    标签: tensorflow-federated


    【解决方案1】:

    是的,我认为 initialize 应该会重新运行初始化程序,并返回这个值。

    但是,有一种方法可以用 TFF 做这样的事情。 TFF 是强类型和功能性的——如果我们可以构造一个具有正确值的参数,它与上述联合平均过程所期望的类型相匹配,那么事情应该“正常工作”。所以这里的目标是构造一个满足这些要求的参数。

    您可以在这里查看FileCheckpointManager's load implementation 以获得一些灵感,但我认为您使用 Keras 的情况更简单。

    假设您像上面一样使用statemodel 您的 Keras 模型,这里有一个解包和重新打包所有内容的快捷方式——如 TFF 教程之一的this section 所示——也就是说, tff.learning.state_with_new_model_weights 的用法。如果您具有上述状态和模型(并且 TF 处于渴望模式),则以下内容应该适合您:

    state = tff.learning.state_with_new_model_weights(
        state,
        trainable_weights=[v.numpy() for v in model.trainable_weights],
        non_trainable_weights=[
            v.numpy() for v in model.non_trainable_weights
        ])
    

    这应该将模型的权重重新分配给state 对象的适当元素。

    【讨论】:

    • 谢谢,这似乎是一个充满希望的开始。但是您知道如何解决由此产生的错误:“目标结构的类型为''。但是输入结构是一个序列()长度为 6。Nest 不能保证将一个映射到另一个是安全的。"
    • 另外,我很困惑为什么我的原始方法不起作用。它似乎适用于本教程tensorflow.org/federated/tutorials/…
    • 你必须转换state
    • @Eliza 如何转换它?
    • state 在您的示例中是 tff.python.common_libs.anonymous_tuple.AnonymousTuple (IIRC) 的一个实例,它与 tf.convert_to_tensor 不兼容,正如 save_checkpoint 所需要的并在其文档字符串中声明的那样。 TFF 研究代码中经常使用的通用解决方案是引入一个 Python attr s 类,以便在返回状态后立即从匿名元组转换
    猜你喜欢
    • 2017-07-16
    • 2020-05-12
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-07-28
    • 1970-01-01
    • 1970-01-01
    • 2018-04-15
    相关资源
    最近更新 更多