【发布时间】:2020-05-04 14:30:13
【问题描述】:
我有一个 tf.Module 类,其中包含一个(不可选择的)tf.keras.Model 作为子模块。我想知道在这种情况下序列化tf.Module 的推荐方法是什么?
我考虑了两种方法:
- 使用类似于
tf.keras.Model.save的内容。我希望tf.Modules 能够像tf.Model.save一样保存嵌套模块。但是,tf.Module并没有实现这样的功能。 - 腌制,这将是序列化
tf.Module的一种简单方法,但我不能这样做,因为tf.keras.Model是不可腌制的。
这是当前失败的示例代码:
import pickle
import tensorflow as tf
class TestModule(tf.Module):
def __init__(self, model):
self.model = model
def main():
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
# Note, model *is not* picklable.
model = tf.keras.Model(x, y)
_ = model(tf.random.uniform((1, 3)))
module_1 = TestModule(model)
module_2 = pickle.loads(pickle.dumps(module_1))
for variable_1, variable_2 in zip(module_1.model.trainable_variables,
module_2.model.trainable_variables):
tf.debugging.assert_equal(variable_1, variable_2)
if __name__ == '__main__':
main()
我应该为每个tf.Module 编写自定义泡菜功能(例如__{get,set}state__)还是应该创建keras.Models 拥有的类似.save 方法?
【问题讨论】:
标签: python tensorflow keras tensorflow2.0