【问题标题】:Cannot load the model in Keras无法在 Keras 中加载模型
【发布时间】:2021-11-30 13:00:51
【问题描述】:

我已经完成了 NN 的训练,我保存了模型并再次加载它。我收到了这个错误。

TypeError: __init__() missing 1 required positional argument: 'projection_dim'

我无法理解问题的原因和问题所在。 这是我的代码,其中包含 projection_dim

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.prj_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        a = self.prj_dim
        return encoded

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
            'prj_dim': self.prj_dim   
            })
        return config

这是任何有兴趣的人的 google colab notebook 的链接。 https://colab.research.google.com/drive/1LPi_xPe6kFV1eNTKWwHq1onJsDXep_Mj?usp=sharing

【问题讨论】:

    标签: python keras deep-learning tensorflow2.0


    【解决方案1】:

    您需要在get_config 方法中将prj_dim 更改为projection_dim,因为它需要位置参数projection_dim。这边——

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patches': self.num_patches,
            'projection': self.projection,
            'position_embedding': self.position_embedding,
            'projection_dim': self.prj_dim   
            })
        return config
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-11-07
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-05-03
      相关资源
      最近更新 更多