【问题标题】:TF 2.0 @tf.function exampleTF 2.0 @tf.function 示例
【发布时间】:2019-03-22 13:20:45
【问题描述】:

autograph 部分的 tensorflow 文档中,我们有以下代码 sn-p

@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())

我有一个关于step 变量的小问题,它是一个整数而不是张量,签名支持内置的python 类型,例如整数。因此 tf.equal(step%10,0) 可以简单地更改为 step%10 == 0 对吗?

【问题讨论】:

    标签: python tensorflow tensorflow2.0


    【解决方案1】:

    是的,你是对的。整数变量 step 仍然是 Python 变量,即使转换为它的图形表示。调用tf.autograph.to_code(train.python_function)可以看到转换结果。

    不报告所有代码,只报告step 变量相关部分,您会看到

      def loop_body(loop_vars, loss_1, step_1):
        with ag__.function_scope('loop_body'):
          x, y = loop_vars
          step_1 += 1
    

    仍然是一个 python 操作(否则如果第 1 步是 tf.Tensor,它将是 step_1.assign_add(1))。

    有关 autograph 和 tf.function 的更多信息,我建议阅读文章 https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/,该文章很容易解释转换函数时会发生什么。

    【讨论】:

    • 非常感谢!我只是发现有点违反直觉,在文档中他们没有利用这一点并使用tf.equal 而不是==
    • 真正的优势是将 step 声明为 tf.Variable 并使用 step.assign_add 来增加值 - 这样在图形和 python 之间没有上下文切换,并且执行速度提高了跨度>
    【解决方案2】:

    虽然这在生成的代码中不可见,但 step 变量实际上会被 for 循环自动装箱为张量,该循环正在转换为 TF while_loop。

    您可以通过添加打印语句来验证:

        loss = train_one_step(model, optimizer, x, y)
        print(step)
        if tf.equal(step % 10, 0):
    

    【讨论】:

      猜你喜欢
      • 2019-09-17
      • 1970-01-01
      • 1970-01-01
      • 2019-10-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-08-13
      相关资源
      最近更新 更多