【问题标题】:Keras TypeError when loading a model with a custom layer加载带有自定义层的模型时出现 Keras TypeError
【发布时间】:2021-01-05 06:46:26
【问题描述】:

我正在保存的模型中有一个自定义 Keras 层。我想加载这个模型。这是我用来这样做的代码:

self.model = load_model(path, custom_objects={'MyLayer': MyLayer, 'custom_loss_fn': custom_loss_fn})

这是我在模型中使用的自定义层:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super(MyLayer, self).__init__(units, **kwargs)
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, values):
        ...

    def get_config(self):
        config = super().get_config()
        config.update({
          'w1': self.W1,
          'w2': self.W2,
          'v': self.V,
        })
        return config

当我尝试加载模型时,出现以下错误:

TypeError: __init__() missing 1 required positional argument: 'units'

这是为什么呢?

更新

我正在使用 kera 的 ModelCheckpoint 回调保存模型。这可能是兼容性问题吗?

【问题讨论】:

  • 如果将units = self.units 放在init func 的顶部会怎样?
  • 你的意思是 self.units = units?
  • 是的,这能解决问题吗?
  • 并尝试将 super() 语句放在 init 函数的底部
  • 以上都不起作用。

标签: python tensorflow keras keras-layer


【解决方案1】:

尝试相应地修改init函数

def __init__(self, units, **kwargs):    
   self.units = units
   ...
   super(MyCustomLayer, self).__init__(**kwargs)

【讨论】:

    猜你喜欢
    • 2019-03-26
    • 1970-01-01
    • 2022-12-17
    • 1970-01-01
    • 2019-12-21
    • 1970-01-01
    • 2020-09-07
    • 1970-01-01
    • 2020-09-12
    相关资源
    最近更新 更多