【问题标题】:How to remove training=True from the inbound nodes of a layer in an existing model?如何从现有模型中图层的入站节点中删除 training=True?
【发布时间】:2021-09-16 19:14:16
【问题描述】:

假设有一个模型以h5 文件的形式给出,也就是说,我无法更改构建模型架构的代码:

from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)
model.save('model.h5', include_optimizer=False)

现在我想删除training=True 部分,即我想将BatchNormalization 视为附加到没有此标志的模型。

我目前的尝试如下:

import numpy as np
from tensorflow.keras.models import load_model

model = load_model('model.h5')

for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

model.predict(np.asarray([[1, 2, 3, 4]]))

model.predict 调用失败并出现以下错误(我使用的是 TensorFlow 2.5.0):

ValueError: Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements.  Structure: ((<KerasTensor: shape=(None, 4) dtype=float32 (created by layer 'input_1')>,), {}), flat_sequence: [<tf.Tensor 'model/Cast:0' shape=(None, 4) dtype=float32>, True].

如何解决/解决这个问题?

(当使用node.call_kwargs["training"] = False 而不是del node.call_kwargs["training"]model.predict 不会崩溃,但它只是表现得好像什么都没有改变,即,修改后的标志被忽略。)

【问题讨论】:

  • 我不确定这是否可能。 h5 文件已经规范化,不能以不同的方式进行非规范化和重新规范化。无论如何,您是否尝试将其设置为 false 而不是完全删除参数?
  • 在model.layers中尝试层:
  • @AlmogAtNailo 当使用node.call_kwargs["training"] = Falsemodel.predict 不会崩溃,但它只是表现得好像没有任何改变,即,修改后的标志被忽略。但是谢谢你的评论。我已经相应地调整了我的问题。

标签: python tensorflow keras tensorflow2.0 tf.keras


【解决方案1】:

我发现,修改call_kwargs帮助后再次保存并重新加载模型。

import numpy as np
from tensorflow.keras.models import load_model

model = load_model('model.h5')

# Removing training=True
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

# The two following lines are the solution.
model.save('model_modified.h5', include_optimizer=False)
model = load_model('model_modified.h5')

model.predict(np.asarray([[1, 2, 3, 4]]))

一切都很好。 :)

【讨论】:

    【解决方案2】:

    你试过了吗

    for layer in model.layers:
        layer.trainable=False
    

    【讨论】:

    • "trainable": false 表示该层在训练期间被冻结(有助于微调/transfew-learning)。它独立于"training": true,这与某些层类型(Dropout/BatchNormalization)的行为方式有关。下面是一个例子,展示了这两个东西的自治性:gist.github.com/Dobiasd/d9a8a24f22e1acc2db8fa1de757cfe9d
    猜你喜欢
    • 1970-01-01
    • 2019-12-17
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2013-07-13
    • 1970-01-01
    相关资源
    最近更新 更多