我能够使用以下代码重新创建您的错误 -
重现错误的代码 -
%tensorflow_version 2.x
import tensorflow as tf
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step_with_opt(a, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(a*x - y))
tr_weights = [a]
gradients = tape.gradient(L, tr_weights)
optimizer.apply_gradients(zip(gradients, tr_weights))
return a
a = tf.Variable(2.)
x = tf.Variable([-1.,-1.,-1.], dtype = tf.float32)
y = tf.Variable([2.,2.,2.], dtype = tf.float32)
train_step_with_opt(a, x, y, opt1) # works
print("First Run was fine")
train_step_with_opt(a, x, y, opt2) # fails
print("Second Run was fine")
输出 -
First Run was fine
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-2-52022dc5007d> in <module>()
21 train_step_with_opt(a, x, y, opt1) # works
22 print("First Run was fine")
---> 23 train_step_with_opt(a, x, y, opt2) # fails
24 print("Second Run was fine")
7 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
ValueError: in user code:
<ipython-input-1-4386b333360b>:13 train_step_with_opt *
optimizer.apply_gradients(zip(gradients, tr_weights))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:478 apply_gradients **
self._create_all_weights(var_list)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:661 _create_all_weights
_ = self.iterations
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:668 __getattribute__
return super(OptimizerV2, self).__getattribute__(name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:793 iterations
aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:997 add_weight
aggregation=aggregation)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/base.py:743 _add_variable_with_custom_getter
**kwargs_for_getter)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer_utils.py:141 make_variable
shape=variable_shape if variable_shape else None)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:259 __call__
return cls._variable_v1_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:220 _variable_v1_call
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:66 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:511 invalid_creator_scope
"tf.function-decorated function tried to create "
ValueError: tf.function-decorated function tried to create variables on non-first call.
解决方案 - 由于您在 TF 2.x 中使用了 @tf.function 装饰器,可以在导入 TensorFlow 之后立即启用函数的即时(eager)执行,这样被装饰的函数不会被编译成图,也就不会触发"在非首次调用时创建变量"的错误。我通过在 import tensorflow 之后添加 tf.config.experimental_run_functions_eagerly(True) 解决了这个问题。
你可以在 TensorFlow 官方文档中找到更多关于 tf.config.experimental_run_functions_eagerly 的信息。
固定代码 -
%tensorflow_version 2.x
import tensorflow as tf
tf.config.experimental_run_functions_eagerly(True)
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step_with_opt(a, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(a*x - y))
tr_weights = [a]
gradients = tape.gradient(L, tr_weights)
optimizer.apply_gradients(zip(gradients, tr_weights))
return a
a = tf.Variable(2.)
x = tf.Variable([-1.,-1.,-1.], dtype = tf.float32)
y = tf.Variable([2.,2.,2.], dtype = tf.float32)
train_step_with_opt(a, x, y, opt1) # works
print("First Run was fine")
train_step_with_opt(a, x, y, opt2) # fails
print("Second Run was fine")
输出 -
First Run was fine
Second Run was fine
希望这能回答您的问题。快乐学习。