【发布时间】:2026-01-09 05:20:03
【问题描述】:
我想在 `create_keras_model()` 内部加载一个预训练网络,
所以我写了下面这段代码:
def create_keras_model():
    """Build a 3-unit softmax classifier on top of a saved Keras model.

    The saved model is loaded from ``model_path`` (defined outside this
    function) with ``compile=False``, and a new ``Dense`` layer named
    ``"output"`` is applied to its output tensor.

    Returns:
        A ``tf.keras.Model`` that shares the loaded model's input and whose
        output is the new softmax layer.
    """
    pretrained = tf.keras.models.load_model(model_path, compile=False)
    classifier = tf.keras.layers.Dense(3, activation="softmax", name="output")
    return tf.keras.Model(
        inputs=pretrained.input,
        outputs=classifier(pretrained.output),
    )
def model_fn():
    """Wrap a freshly built Keras model as a TFF learning model.

    A new model is created on every call via ``create_keras_model``; it is
    paired with the module-level ``input_spec``, a categorical cross-entropy
    loss and a categorical accuracy metric.

    Returns:
        The object produced by ``tff.learning.from_keras_model``.
    """
    return tff.learning.from_keras_model(
        create_keras_model(),
        input_spec=input_spec,
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()],
    )
@tff.tf_computation
def get_weights_from_disk():
    """Return the weights of the model built by ``create_keras_model``.

    A TFF TensorFlow computation must return tensors (or nested structures of
    them). A ``tf.keras.Model`` object is not such a value, so the model's
    trainable and non-trainable variables are packaged as
    ``tff.learning.ModelWeights`` instead of returning the model itself.

    Returns:
        A ``tff.learning.ModelWeights`` built from the loaded Keras model.
    """
    keras_model = create_keras_model()
    return tff.learning.ModelWeights.from_model(keras_model)
@tff.federated_computation
def server_init():
    """Produce the initial server value by evaluating the loader at the server.

    Returns:
        The result of ``get_weights_from_disk`` placed at ``tff.SERVER``.
    """
    # There may be state other than weights that needs to get returned from here,
    # as in the implementation of build_federated_averaging_process.
    # federated_eval expects the no-arg computation itself; calling it here
    # would pass its result instead of a function.
    return tff.federated_eval(get_weights_from_disk, tff.SERVER)
# Stock federated averaging process; only its `next` step is reused below.
old_iterproc = tff.learning.build_federated_averaging_process(
    model_fn=model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.001),
)

# Pair the custom initializer with the stock `next` step. The keyword is
# `initialize_fn`; the misspelled `intialize_fn` is rejected as an unexpected
# keyword argument.
# NOTE(review): IterativeProcess presumably requires the value returned by
# `server_init` to match the state type `old_iterproc.next` accepts -- confirm
# that `server_init` returns the full server state, not only the weights.
new_iterproc = tff.templates.IterativeProcess(
    initialize_fn=server_init,
    next_fn=old_iterproc.next,
)
state = new_iterproc.initialize()
【问题讨论】:
标签: tensorflow tensorflow-federated