【发布时间】:2020-07-07 23:29:42
【问题描述】:
我正在尝试创建一个过滤器,该过滤器取决于当前的训练global_step,但我没有正确地这样做。
首先,我不能在下面的代码中使用tf.train.get_or_create_global_step(),因为它会抛出
ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
这就是我尝试使用tf.get_default_graph().get_name_scope() 获取范围的原因,并且在该上下文中我能够“get”全局步骤:
def filter_examples(example):
scope = tf.get_default_graph().get_name_scope()
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
current_step = tf.train.get_or_create_global_step()
subtokens_by_step = tf.floor(current_step / curriculum_step_update)
max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)
return tf.size(example['targets']) <= max_subtokens
dataset = dataset.filter(filter_examples)
问题在于它似乎没有按我预期的那样工作。根据我的观察,上面代码中的current_step 似乎一直为0(我不知道,只是根据我的观察,我假设)。
唯一似乎有所作为,听起来很奇怪,就是重新开始训练。我认为,同样基于观察,在这种情况下,current_step 将是此时培训的实际当前步骤。但值本身不会随着训练的继续而更新。
如果有办法获取当前步骤的 实际 值并在我的过滤器中使用它吗?
环境
张量流 1.12.1
【问题讨论】:
-
global_step在哪里更新? -
@rvinas 这是一个很好的问题.. 我正在使用tensor2tensor (t2t) 并且只实现了我自己的
problem。上面的代码在 t2t 数据管道执行期间被调用并且应该只返回一个tf.data.Dataset。我希望能够以某种方式获取由 t2t 创建的global_step。 -
我明白了。我不熟悉 t2t 框架,但我们应该确保
global_step在某处得到更新,否则,它的值保持为 0 是有道理的。也许另一个选择(可能不优雅)是维护您的自定义 @987654334 @? -
@rvinas 所以
global_step确实得到了更新,但我认为我无法在图表中获得对该特定变量的引用。如前所述,例如,get_or_create_global_step()确实会引发异常。正如你所说,使用我自己的global_step或counter变量是可以的,但我真的不知道如何才能完成这样的事情。有没有办法例如每当在当前上下文中执行步骤时更新变量? -
@rvinas 您对
tf.control_dependencies的想法正在发挥作用。这不是我一开始希望它做的方式,但它没关系并且服务于它的目的。如果您想提供答案,那么赏金就是您的。 :)