【发布时间】:2019-02-11 18:35:26
【问题描述】:
我存储了一个包含 .meta、.index、checkpoint 和 .data-0001 文件的 Tensorflow 模型。我使用以下方法恢复我的图表和模型:
model = tf.train.import_meta_graph("models/model.meta")
model.restore(sess, tf.train.latest_checkpoint("models/"))
我恢复了一些变量,例如权重和偏差,但我还需要恢复损失函数。我的模型使用nce_loss。
本质上,我想在给定特定输入的情况下获得损失函数的梯度,我不必重新定义损失变量,只需从恢复的版本调用操作即可。所以:
loss = graph.get_operation_by_name("loss")
grads = tf.gradients(loss,loss.inputs)
在这里我收到以下错误消息:
File "/tmp/fgsm.py", line 114, in main
grads = tf.gradients(loss,loss.inputs)
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 630, in gradients
gate_gradients, aggregation_method, stop_gradients)
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 675, in _GradientsHelper
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1377, in convert_n_to_tensor_or_indexed_slices
values=values, dtype=dtype, name=name, as_ref=False)
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1348, in internal_convert_n_to_tensor_or_indexed_slices
value, dtype=dtype, name=n, as_ref=as_ref))
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1307, in internal_convert_to_tensor_or_indexed_slices
value, dtype=dtype, name=name, as_ref=as_ref)
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1146, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/tmp/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 6168, in _operation_conversion_error
name, as_ref))
TypeError: Can't convert Operation 'loss' to Tensor (target dtype=None, name='y_0', as_ref=False)
我在这里做错了什么?
编辑:
所以切换到
loss = graph.get_tensor_by_name("loss:0")
我可以成功获得我的损失张量。现在,在给定恢复的损失函数的情况下,如何获得输入的梯度?
nce_loss 有一个“输入”参数,我想计算给定损失函数和输入参数的梯度。我该如何使用tf.gradients?当我做tf.gradients(loss,loss.inputs) 时,我得到一个错误
AttributeError: 'Tensor' object has no attribute 'inputs'
【问题讨论】:
标签: python tensorflow gradient