【发布时间】:2021-09-16 19:14:16
【问题描述】:
假设有一个模型以h5 文件的形式给出,也就是说,我无法更改构建模型架构的代码:
from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model
inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)
model.save('model.h5', include_optimizer=False)
现在我想删除training=True 部分,即我想将BatchNormalization 视为附加到没有此标志的模型。
我目前的尝试如下:
import numpy as np
from tensorflow.keras.models import load_model
model = load_model('model.h5')
for layer in model.layers:
for node in layer.inbound_nodes:
if "training" in node.call_kwargs:
del node.call_kwargs["training"]
model.predict(np.asarray([[1, 2, 3, 4]]))
但 model.predict 调用失败并出现以下错误(我使用的是 TensorFlow 2.5.0):
ValueError: Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements. Structure: ((<KerasTensor: shape=(None, 4) dtype=float32 (created by layer 'input_1')>,), {}), flat_sequence: [<tf.Tensor 'model/Cast:0' shape=(None, 4) dtype=float32>, True].
如何解决/解决这个问题?
(当使用node.call_kwargs["training"] = False 而不是del node.call_kwargs["training"] 时model.predict 不会崩溃,但它只是表现得好像什么都没有改变,即,修改后的标志被忽略。)
【问题讨论】:
-
我不确定这是否可能。 h5 文件已经规范化,不能以不同的方式进行非规范化和重新规范化。无论如何,您是否尝试将其设置为 false 而不是完全删除参数?
-
在model.layers中尝试层:
-
@AlmogAtNailo 当使用
node.call_kwargs["training"] = False时model.predict不会崩溃,但它只是表现得好像没有任何改变,即,修改后的标志被忽略。但是谢谢你的评论。我已经相应地调整了我的问题。
标签: python tensorflow keras tensorflow2.0 tf.keras