【问题标题】:What's the recommended way to serialize `tf.Module`s?序列化 `tf.Module`s 的推荐方法是什么?
【发布时间】:2020-05-04 14:30:13
【问题描述】:

我有一个 tf.Module 类,其中包含一个(不可选择的)tf.keras.Model 作为子模块。我想知道在这种情况下序列化tf.Module 的推荐方法是什么?

我考虑了两种方法:

  1. 使用类似于tf.keras.Model.save 的内容。我希望tf.Modules 能够像tf.Model.save 一样保存嵌套模块。但是,tf.Module 并没有实现这样的功能。
  2. 腌制,这将是序列化tf.Module 的一种简单方法,但我不能这样做,因为tf.keras.Model 是不可腌制的。

这是当前失败的示例代码:

import pickle

import tensorflow as tf


class TestModule(tf.Module):
    def __init__(self, model):
        self.model = model


def main():
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    # Note, model *is not* picklable.
    model = tf.keras.Model(x, y)

    _ = model(tf.random.uniform((1, 3)))

    module_1 = TestModule(model)
    module_2 = pickle.loads(pickle.dumps(module_1))

    for variable_1, variable_2 in zip(module_1.model.trainable_variables,
                                      module_2.model.trainable_variables):
        tf.debugging.assert_equal(variable_1, variable_2)


if __name__ == '__main__':
    main()

我应该为每个tf.Module 编写自定义泡菜功能(例如__{get,set}state__)还是应该创建keras.Models 拥有的类似.save 方法?

【问题讨论】:

    标签: python tensorflow keras tensorflow2.0


    【解决方案1】:

    您可以使用Saved Model Format 保存自定义tf.Module 子类。

    以下适用于 Tensorflow 2.1:

    import tensorflow as tf
    
    class TestModule(tf.Module):
        def __init__(self, model):
            self.model = model
    
    
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    model = tf.keras.Model(x, y)
    module_1 = TestModule(model)
    
    tf.saved_model.save(module_1, "./foo")
    

    要加载回:
    imported = tf.saved_model.load("foo")

    断言
    module_1 == imported(或类似的)将引发AssertionError,因为在加载后我们正在处理不同的 Tensorflow 对象。然而,我们可以迭代模型的权重并逐元素比较它们:

    original_weights = module_1.model.weights
    imported_weights = imported.model.variables.weights
    
    for weight_idx, _ in enumerate(original_weights):
      assert (
          original_weights[weight_idx].numpy() == imported_weights[weight_idx].numpy()
          ).all()
    

    【讨论】:

    • 这似乎是个不错的选择。我会稍等片刻,看看是否有其他有用的答案出现。不过,还有一个关于此的问题:似乎imported 缺少trainable_variables 属性,即使它存在于module_1 中。任何想法为什么会这样?我得到一个error AttributeError: '_UserObject' object has no attribute 'trainable_variables'。我为此打开了一个 github 问题,因为这对我来说似乎是一个错误:github.com/tensorflow/tensorflow/issues/36021
    • 确实如此。 trainable_variables 仅存在于 importedmodel 子部分中。虽然,我无法回答为什么,因为我不知道那么多序列化细节。
    猜你喜欢
    • 1970-01-01
    • 2011-05-14
    • 1970-01-01
    • 2010-10-19
    • 2021-02-22
    • 2021-08-07
    • 2013-04-04
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多