【问题标题】:How to pickle Keras custom layer?如何腌制 Keras 自定义层?
【发布时间】:2019-01-07 08:30:42
【问题描述】:

我写了一个自定义层类扩展了层类,然后我想腌制历史以进行进一步分析,但是当我从文件中重新加载腌制对象时,python会引发错误:

未知层:注意力。

那么,我该如何解决呢?

我都尝试过get_config__getstate____setstate__,但都失败了。我只想腌制keras的历史,而不是模型,所以请不要告诉我带有custom_object参数的保存模型方法。

【问题讨论】:

  • 您能否添加可重现的代码来说明您的方法为何不起作用?

标签: python keras pickle


【解决方案1】:

出现此问题是因为在转储历史时,它无法转储完整模型。所以在加载的时候,找不到自定义类。

我注意到keras.callbacks.History 对象有一个属性model,它的不完整转储是导致此问题的原因。

你说:

我只想腌制 keras 的历史,而不是模型

所以以下是一种解决方法:

hist = model.fit(X, Y, ...)
hist.model = None

只需将model 属性设置为None,就可以成功转储和加载历史对象!

以下是 MVCE:

from keras.models import Sequential
from keras.layers import Conv2D, Dense, Flatten, Layer
import keras.backend as K
import numpy as np
import pickle

# MyLayer from https://keras.io/layers/writing-your-own-keras-layers/
class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=(3,3), input_shape=(28,28,3), activation='sigmoid'))
model.add(Flatten())
model.add(MyLayer(10))
model.add(Dense(3, activation='softmax'))

model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

model.summary()

X = np.random.randn(64, 28, 28, 3)
Y = np.random.randint(0, high=2, size=(64,1))

hist = model.fit(X, Y, batch_size=8)

hist.model = None

with open('hist.pkl', 'wb') as f:
    pickle.dump(hist, f)

with open('hist.pkl', 'rb') as f:
    hist_reloaded = pickle.load(f)

print(hist.history)
print(hist_reloaded.history)

输出:

{'acc': [0.484375], 'loss': [6.140302091836929]}

{'acc': [0.484375], 'loss': [6.140302091836929]}

附:如果想用自定义层保存 keras 模型,this 应该会有所帮助。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2018-06-26
    • 2022-06-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-03-27
    • 1970-01-01
    相关资源
    最近更新 更多