【问题标题】:Saving Tensorflow model that uses BatchNorm保存使用 BatchNorm 的 TensorFlow 模型
【发布时间】:2019-01-14 17:24:45
【问题描述】:

我正在尝试使用 Tensorflow 从 GAN 中保存生成器模型。我使用的模型有几个批规范层。当我保存权重时,只有运行全局变量初始化程序才能成功恢复它们,我不应该这样做,因为正在恢复所有变量。如果我在恢复之前运行全局变量初始化程序,当我使用加载的权重运行推理并为批量规范参数设置 is_training=False 时,模型的性能非常差。但是,如果 is_training=True,则模型按预期执行。这种行为应该完全相反。

为了节省重量,我这样做:

t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'g_' in var.name]
g_saver = tf.train.Saver(g_vars)
... train model ...
g_saver.save(sess, "weights/generator/gen.ckpt")

当我恢复权重时,我使用相同的模型定义并执行以下操作:

t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'g_' in var.name]

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

g_saver = tf.train.Saver(g_vars)
g_saver.restore(sess, "./weights/generator/gen.ckpt")

您是否需要执行特殊程序来计算批次规范权重?我是否缺少一些变量集合?

编辑:

我使用以下方法定义批处理规范层:

conv1_norm = tf.contrib.layers.batch_norm(conv1, is_training=training

我发现将 variables_collections=["g_batch_norm_non_trainable"] 添加到 batch_norm 函数中,然后做

g_vars = list(set([var for var in t_vars if 'g_' in var.name] + tf.get_collection("g_batch_norm_non_trainable")))

有效,但对于本应简单的减肥指令来说,这似乎相当复杂。

【问题讨论】:

  • 如何定义批规范层?能否在您的问题中添加这部分或一批规范层的示例?
  • 我刚刚添加了一个编辑,如果有帮助请告诉我!

标签: python tensorflow deep-learning


【解决方案1】:

当您使用 tf.contrib.layers.batch_norm 和默认参数定义批量标准化时,会创建三个变量:betamoving_meanmoving_variance。第一个是唯一可训练的变量,另外两个包含在tf.GraphKeys.GLOBAL_VARIABLES 集合中。

这就是为什么g_vars 使用下一行中的可训练变量定义的原因没有在列表中同时获得moving_meanmoving_variance

g_vars = [var for var in t_vars if 'g_' in var.name]

由于您似乎只想保存生成器变量,因此我建议使用变量范围来定义您的生成器网络。

对随机张量进行上采样并使用批量标准化的示例:

import tensorflow as tf
import numpy as np

input_layer = tf.placeholder(tf.float32, (2, 7, 7, 64))  # (batch, height, width, in_channels)

with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
    # define your generator network here ...
    t_conv_layer = tf.layers.conv2d_transpose(input_layer,
                    filters=32, kernel_size=[3, 3], strides=(2, 2), padding='SAME', name='t_conv_layer')

    batch_norm = tf.contrib.layers.batch_norm(t_conv_layer, is_training=True, scope='my_batch_norm')
    print(batch_norm) # Tensor("generator/my_batch_norm/FusedBatchNorm:0", shape=(2, 14, 14, 32), dtype=float32)

您可以通过打印来检查tf.trainable_variables()tf.global_variables() 的变量列表。 由于可训练变量在here 描述的全局变量列表中,我们可以将g_vars 定义为:

g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')

如果我们检查这个列表,我们将拥有我们想要的批处理规范的所有变量:

for var in g_vars:
    print("variable_name: {:45}, nb_parameters: {}".format(var.name, np.prod(var.get_shape().as_list())))

产生输出:

variable_name: generator/t_conv_layer/kernel:0              , nb_parameters: 18432
variable_name: generator/t_conv_layer/bias:0                , nb_parameters: 32
variable_name: generator/my_batch_norm/beta:0               , nb_parameters: 32
variable_name: generator/my_batch_norm/moving_mean:0        , nb_parameters: 32
variable_name: generator/my_batch_norm/moving_variance:0    , nb_parameters: 32

【讨论】:

  • 这实际上是行不通的。当我像你一样在单个范围内定义生成器,然后收集变量并加载权重时,我得到一个“未找到密钥:生成器/batchnorm/beta/Adam”错误...你知道这是什么原因吗?
  • 这很奇怪。它对我有用。我正在定义生成器和g_vars,如答案和saver = tf.train.Saver(var_list=g_vars)。然后,当使用saver.restore(sess, tf.train.latest_checkpoint(dir_ckpt_file)) 恢复时,模型会恢复。你能分享一个最小的例子来重现你的错误吗?
  • 已解决:使用 saver.restore(sess, path/to/model.ckpt) 导致错误...但使用 tf.train.latest_checkpoint(path/to/dir) 有效...
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2017-11-09
  • 1970-01-01
  • 1970-01-01
  • 2021-02-19
  • 2020-01-03
  • 2022-12-21
  • 2018-02-26
相关资源
最近更新 更多