【问题标题】:How to use GradientTape with AutoGraph in Tensorflow 2?如何在 Tensorflow 2 中使用 GradientTape 和 AutoGraph?
【发布时间】:2019-09-07 05:36:22
【问题描述】:

我不知道如何在 Tensorflow 2 中的 AutoGraph 上运行 GradientTape 代码。

我想在 TPU 上运行 GradientTape 代码。我想先在 CPU 上测试它。 TPU 代码使用 AutoGraph 会运行得更快。我尝试查看输入变量,并尝试将参数传递给包含 GradientTape 的函数,但都失败了。

我在这里做了一个可重现的例子: https://colab.research.google.com/drive/1luCk7t5SOcDHQC6YTzJzpSixIqHy_76b#scrollTo=9OQUYQtTTYIt

代码及对应输出如下: 它们都以import tensorflow as tf开头

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x)
print(dy_dx)

输出:tf.Tensor(6.0, shape=(), dtype=float32) 说明:使用 Eager Execution,GradientTape 产生渐变。

@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
      g.watch(x)
      y = x * x
    dy_dx = g.gradient(y, x) # Will compute to 6.0
    print(dy_dx)
compute_me()

输出:Tensor("AddN:0", shape=(), dtype=float32) 说明:在 TF2 中对 GradientTape 使用 AutoGraph 会导致渐变为空

@tf.function
def compute_me_args(x):
    with tf.GradientTape() as g:
      g.watch(x)
      y = x * x
    dy_dx = g.gradient(y, x) # Will compute to 6.0
    print(dy_dx)    
x = tf.constant(3.0)
compute_me_args(x)

输出:Tensor("AddN:0", shape=(), dtype=float32) 说明:传入参数也失败

我希望所有单元格都输出 tf.Tensor(6.0, shape=(), dtype=float32),但使用 AutoGraph 的单元格输出 Tensor("AddN:0", shape=(), dtype=float32)

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:

    它不会“失败”,只是 print,如果在 tf.function 的上下文中使用(即在图形模式下)将打印符号张量,而这些张量没有值。试试这个:

    @tf.function
    def compute_me():
        x = tf.constant(3.0)
        with tf.GradientTape() as g:
            g.watch(x)
            y = x * x
        dy_dx = g.gradient(y, x) # Will compute to 6.0
        tf.print(dy_dx)
    compute_me()
    

    这应该打印6。您需要做的就是改用tf.print,它足够“智能”,可以打印实际值(如果有)。或者,使用返回值:

    @tf.function
    def compute_me():
        x = tf.constant(3.0)
        with tf.GradientTape() as g:
            g.watch(x)
            y = x * x
        dy_dx = g.gradient(y, x) # Will compute to 6.0
        return dy_dx
    result = compute_me()
    print(result)
    

    输出类似<tf.Tensor: id=43, shape=(), dtype=float32, numpy=6.0> 的东西。您可以看到值 (6.0) 在此处也可见。使用result.numpy() 获取6.0

    【讨论】:

    • 测试了代码示例,它们适用于 Tensorflow 2.0.0-rc0 和 Tensorflow 2.0.0-beta1
    猜你喜欢
    • 1970-01-01
    • 2020-04-22
    • 2023-04-02
    • 2019-10-28
    • 1970-01-01
    • 1970-01-01
    • 2020-07-31
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多