【发布时间】:2021-01-05 06:46:26
【问题描述】:
我正在保存的模型中有一个自定义 Keras 层。我想加载这个模型。这是我用来这样做的代码:
self.model = load_model(path, custom_objects={'MyLayer': MyLayer, 'custom_loss_fn': custom_loss_fn})
这是我在模型中使用的自定义层:
class MyLayer(tf.keras.layers.Layer):
def __init__(self, units, **kwargs):
super(MyLayer, self).__init__(units, **kwargs)
self.W1 = tf.keras.layers.Dense(units)
self.W2 = tf.keras.layers.Dense(units)
self.V = tf.keras.layers.Dense(1)
def call(self, values):
...
def get_config(self):
config = super().get_config()
config.update({
'w1': self.W1,
'w2': self.W2,
'v': self.V,
})
return config
当我尝试加载模型时,出现以下错误:
TypeError: __init__() missing 1 required positional argument: 'units'
这是为什么呢?
更新
我正在使用 kera 的 ModelCheckpoint 回调保存模型。这可能是兼容性问题吗?
【问题讨论】:
-
如果将units = self.units 放在init func 的顶部会怎样?
-
你的意思是 self.units = units?
-
是的,这能解决问题吗?
-
并尝试将 super() 语句放在 init 函数的底部
-
以上都不起作用。
标签: python tensorflow keras keras-layer