【问题标题】:deepcopy class containing keras models包含 keras 模型的 deepcopy 类
【发布时间】:2018-10-05 21:52:09
【问题描述】:

在我的 python 脚本中,我创建了一个类,其中包含keras 模型,如下所示:

from keras.layers import Input, Activation, Dense
from keras.models import Model


class Klass:

    def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.optimizer = optimizer
        self.a = a
        self.b = b

        self.__build_nn()

    def __build_nn(self):

        inputs = Input(shape=(self.input_dims,))
        net = inputs
        for h_dim in self.hidden_dims:
            net = Dense(h_dim, kernel_initializer='he_uniform')(net)
            net = Activation("relu")(net)

        outputs = Dense(self.output_dims)(net)
        outputs = Activation("linear")(outputs)
        self.nn1 = Model(inputs=inputs, outputs=outputs)
        self.nn2 = Model(inputs=inputs, outputs=outputs)
        self.nn1.compile(optimizer=self.optimizer, loss='mean_squared_error')
        self.nn2.compile(optimizer=self.optimizer, loss='mean_squared_error')

创建Klass 实例后,我想对其进行深层复制:

import copy
obj = Klass(10, 10, (20, 20), Adam(), 1, 2)
obj_dc = copy.deepcopy(obj)

但是,这会引发 TypeError: can't pickle _thread.RLock objects。我很确定该错误与类对象中的 keras 模型有关,因为我能够在没有 keras 模型的情况下获得类似类的深层副本。

不幸的是,我无法在互联网上找到解决方案,因为大多数关于深度复制 keras 模型的问题都试图克隆 keras 模型,如 here

那么,如何获得包含keras 模型的类的深层副本?

编辑

这三个问题(123)在不同情况下都提到了类似的错误。然而,那里提供的解决方案不适用于我的情况。

编辑 2

按照 cmets 中的建议,我在类中添加了 copy 方法。这会是一个可行的解决方案吗?

class Klass:

    def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):

        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.optimizer = optimizer
        self.a = a
        self.b = b

        self.__build_nn()

    # [...]

    def copy(self):

        new = Klass(self.input_dims, self.output_dims, self.hidden_dims,
                    self.optimizer, self.a, self.b)
        new.nn1.set_weights(self.nn1.get_weights())
        new.nn2.set_weights(self.nn2.get_weights())

        return new

【问题讨论】:

  • 听起来很棘手,因为 Keras 在后台做了很多事情,具体取决于您的环境。我将为Klass 编写一个copy 方法,在该方法中我创建另一个Klass 对象并将权重复制到新对象上。这能满足您的需求吗?
  • @KotaMori 感谢您的建议。添加的copy 方法(编辑2)会是实现预期结果的“pythonic”方式吗?
  • 使用keras.models.clone_model 会更好,但我认为复制权重没有太大区别。 github.com/keras-team/keras/issues/1765。虽然不知道这是不是pythonic,但好像也没有那么多其他的选择。

标签: python tensorflow keras pickle deep-copy


【解决方案1】:

在 cmets 中解决:为 Klass 添加了一个 copy 方法,该方法将权重从旧的 Klass 实例复制到新创建的实例。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-10-25
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多