【问题标题】:Loading custom CTC layer from h5 file in Keras从 Keras 中的 h5 文件加载自定义 CTC 层
【发布时间】:2020-07-10 06:03:49
【问题描述】:

我有一个这样的 CTCLayer 类:

class CTCLayer(layers.Layer):
def __init__(self, name=None):
    super().__init__(name=name)
    self.loss_fn = keras.backend.ctc_batch_cost


def call(self, y_true, y_pred):
    # Compute the training-time loss value and add it
    # to the layer using `self.add_loss()`.
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

    loss = self.loss_fn(y_true, y_pred, input_length, label_length)
    self.add_loss(loss)

    # At test time, just return the computed predictions
    return y_pred

我训练了我的模型,将其保存到 model.h5 文件并通过以下方式加载:

model_load = tf.keras.models.load_model('model.h5', custom_objects={'CTCLayer': CTCLayer})

它抛出 init() got an unexpected keyword argument 'trainable' 错误。

由于我不想再次训练我的模型(时间限制),是否有任何解决方法可以加载模型而无需在 CTCLayer 类中添加 get_config()?

如果没有,我应该如何修改类中的 get_config()?

【问题讨论】:

    标签: tensorflow keras ctc


    【解决方案1】:

    这应该可行:

        class CTCLayer(layers.Layer):
        def __init__(self, name=None):
            def __init__(self, name=None, **kwargs):
            super(CTCLayer, self).__init__(name=name, **kwargs)
    

    【讨论】:

      猜你喜欢
      • 2019-12-21
      • 1970-01-01
      • 2021-04-06
      • 2019-08-03
      • 2021-10-27
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-08-02
      相关资源
      最近更新 更多