【发布时间】:2019-01-14 17:24:45
【问题描述】:
我正在尝试使用 Tensorflow 从 GAN 中保存生成器模型。我使用的模型有几个批规范层。当我保存权重时,只有运行全局变量初始化程序才能成功恢复它们,我不应该这样做,因为正在恢复所有变量。如果我在恢复之前运行全局变量初始化程序,当我使用加载的权重运行推理并为批量规范参数设置 is_training=False 时,模型的性能非常差。但是,如果 is_training=True,则模型按预期执行。这种行为应该完全相反。
为了节省重量,我这样做:
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'g_' in var.name]
g_saver = tf.train.Saver(g_vars)
... train model ...
g_saver.save(sess, "weights/generator/gen.ckpt")
当我恢复权重时,我使用相同的模型定义并执行以下操作:
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'g_' in var.name]
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
g_saver = tf.train.Saver(g_vars)
g_saver.restore(sess, "./weights/generator/gen.ckpt")
您是否需要执行特殊程序来计算批次规范权重?我是否缺少一些变量集合?
编辑:
我使用以下方法定义批处理规范层:
conv1_norm = tf.contrib.layers.batch_norm(conv1, is_training=training
我发现将 variables_collections=["g_batch_norm_non_trainable"] 添加到 batch_norm 函数中,然后做
g_vars = list(set([var for var in t_vars if 'g_' in var.name] + tf.get_collection("g_batch_norm_non_trainable")))
有效,但对于本应简单的减肥指令来说,这似乎相当复杂。
【问题讨论】:
-
如何定义批规范层?能否在您的问题中添加这部分或一批规范层的示例?
-
我刚刚添加了一个编辑,如果有帮助请告诉我!
标签: python tensorflow deep-learning