【问题标题】:Tensorflow function doesn't change attribute's attributeTensorflow 函数不会改变属性的属性
【发布时间】:2022-01-06 05:27:45
【问题描述】:
Tf 函数不会改变对象的属性
class f:
v = 7
def __call__(self):
self.v = self.v + 1
@tf.function
def call(c):
tf.print(c.v) # always 7
c()
tf.print(c.v) # always 8
c = f()
call(c)
call(c)
预期打印:
7
8
8
9
而是:
7
8
7
8
当我删除 @tf.function 装饰器时,一切都按预期工作。如何使用@tf.function 使我的函数按预期工作
【问题讨论】:
标签:
python
python-3.x
tensorflow
tensorflow2.0
【解决方案1】:
此行为记录在 here:
副作用,如打印、附加到列表和改变全局变量,可能会在函数内部出现意外行为,有时会执行两次或不全部执行。它们仅在您第一次使用一组输入调用函数时发生。之后,跟踪的 tf.Graph 被重新执行,而不执行 Python 代码。一般的经验法则是避免在逻辑中依赖 Python 副作用,只使用它们来调试跟踪。否则,tf.data、tf.print、tf.summary、tf.Variable.assign 和 tf.TensorArray 等 TensorFlow API 是确保每次调用都由 TensorFlow 运行时执行代码的最佳方式。
所以,也许可以尝试使用tf.Variable 来查看预期的变化:
import tensorflow as tf
class f:
v = tf.Variable(7)
def __call__(self):
self.v.assign_add(1)
@tf.function
def call(c):
tf.print(c.v) # always 7
c()
tf.print(c.v) # always 8
c = f()
call(c)
call(c)
7
8
8
9