【发布时间】:2019-09-07 05:36:22
【问题描述】:
我不知道如何在 Tensorflow 2 中的 AutoGraph 上运行 GradientTape 代码。
我想在 TPU 上运行 GradientTape 代码。我想先在 CPU 上测试它。 TPU 代码使用 AutoGraph 会运行得更快。我尝试查看输入变量,并尝试将参数传递给包含 GradientTape 的函数,但都失败了。
我在这里做了一个可重现的例子: https://colab.research.google.com/drive/1luCk7t5SOcDHQC6YTzJzpSixIqHy_76b#scrollTo=9OQUYQtTTYIt
代码及对应输出如下:
它们都以import tensorflow as tf开头
x = tf.constant(3.0)
with tf.GradientTape() as g:
g.watch(x)
y = x * x
dy_dx = g.gradient(y, x)
print(dy_dx)
输出:tf.Tensor(6.0, shape=(), dtype=float32)
说明:使用 Eager Execution,GradientTape 产生渐变。
@tf.function
def compute_me():
x = tf.constant(3.0)
with tf.GradientTape() as g:
g.watch(x)
y = x * x
dy_dx = g.gradient(y, x) # Will compute to 6.0
print(dy_dx)
compute_me()
输出:Tensor("AddN:0", shape=(), dtype=float32)
说明:在 TF2 中对 GradientTape 使用 AutoGraph 会导致渐变为空
@tf.function
def compute_me_args(x):
with tf.GradientTape() as g:
g.watch(x)
y = x * x
dy_dx = g.gradient(y, x) # Will compute to 6.0
print(dy_dx)
x = tf.constant(3.0)
compute_me_args(x)
输出:Tensor("AddN:0", shape=(), dtype=float32)
说明:传入参数也失败
我希望所有单元格都输出 tf.Tensor(6.0, shape=(), dtype=float32),但使用 AutoGraph 的单元格输出 Tensor("AddN:0", shape=(), dtype=float32)。
【问题讨论】:
标签: python-3.x tensorflow