【问题标题】:How to get current global_step in data pipeline如何在数据管道中获取当前的 global_step
【发布时间】:2020-07-07 23:29:42
【问题描述】:

我正在尝试创建一个过滤器,该过滤器取决于当前的训练global_step,但我没有正确地这样做。

首先,我不能在下面的代码中使用tf.train.get_or_create_global_step(),因为它会抛出

ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

这就是我尝试使用tf.get_default_graph().get_name_scope() 获取范围的原因,并且在该上下文中我能够“get”全局步骤:

def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()

    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)

    return tf.size(example['targets']) <= max_subtokens


dataset = dataset.filter(filter_examples)

问题在于它似乎没有按我预期的那样工作。根据我的观察,上面代码中的current_step 似乎一直为0(我不知道,只是根据我的观察,我假设)。

唯一似乎有所作为,听起来很奇怪,就是重新开始训练。我认为,同样基于观察,在这种情况下,current_step 将是此时培训的实际当前步骤。但值本身不会随着训练的继续而更新。

如果有办法获取当前步骤的 实际 值并在我的过滤器中使用它吗?


环境

张量流 1.12.1

【问题讨论】:

  • global_step 在哪里更新?
  • @rvinas 这是一个很好的问题.. 我正在使用tensor2tensor (t2t) 并且只实现了我自己的problem。上面的代码在 t2t 数据管道执行期间被调用并且应该只返回一个tf.data.Dataset。我希望能够以某种方式获取由 t2t 创建的global_step
  • 我明白了。我不熟悉 t2t 框架,但我们应该确保 global_step 在某处得到更新,否则,它的值保持为 0 是有道理的。也许另一个选择(可能不优雅)是维护您的自定义 @987654334 @?
  • @rvinas 所以global_step 确实得到了更新,但我认为我无法在图表中获得对该特定变量的引用。如前所述,例如,get_or_create_global_step() 确实会引发异常。正如你所说,使用我自己的global_stepcounter 变量是可以的,但我真的不知道如何才能完成这样的事情。有没有办法例如每当在当前上下文中执行步骤时更新变量?
  • @rvinas 您对tf.control_dependencies 的想法正在发挥作用。这不是我一开始希望它做的方式,但它没关系并且服务于它的目的。如果您想提供答案,那么赏金就是您的。 :)

标签: tensorflow tensor2tensor


【解决方案1】:

正如我们在 cmets 中所讨论的,拥有和更新您自己的计数器可能是使用 global_step 变量的替代方法。 counter 变量可以更新如下:

op = tf.assign_add(counter, 1)
with tf.control_dependencies(op): 
    # Some operation here before which the counter should be updated

使用tf.control_dependencies 允许将counter 的更新“附加”到计算图中的路径。然后,您可以在任何需要的地方使用counter 变量。

【讨论】:

    【解决方案2】:

    如果您在数据集中使用变量,则需要在 tf 1.x 中重新初始化迭代器。

    iterator = tf.compat.v1.make_initializable_iterator(dataset)
    init = iterator.initializer
    tensors = iterator.get_next()
    
    with tf.compat.v1.Session() as sess:
        for epoch in range(num_epochs):
            sess.run(init)
            for example in range(num_examples):
                tensor_vals = sess.run(tensors)
    

    【讨论】:

    • 请注意,您最好对初始数据集进行排序,预先计算长度截断值,然后使用Dataset.take
    • 抱歉,这不是一个适用于 tensor2tensor 框架的选项。
    猜你喜欢
    • 2019-06-17
    • 2023-01-10
    • 2020-12-04
    • 2020-02-19
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-06-10
    • 2020-05-24
    相关资源
    最近更新 更多