【发布时间】:2020-04-20 10:50:01
【问题描述】:
我正在尝试从模型中提取所有可训练的权重。
在 pytorch 中,类似的事情可以通过一行 p.grad.data for p in model.parameters() if p.requires_grad 来完成,但是我正在努力在 TF 中想出一个简单的解决方案。
我目前的尝试是这样的:
sess = tf.Session()
... #model initialization and training here
p = model.trainable_weights
p_vals = sess.run(p)
然而,最后一行会产生错误:
File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
return fn(*args)
File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable conv1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/conv1/bias)
[[{{node conv1/bias/Read/ReadVariableOp}}]]
我在这里做错了什么?我假设会话/图表没有正确链接到模型? 或者它确实是一个初始化问题(但模型能够成功训练)?
【问题讨论】:
-
您应该可以看到模型的权重,无需调用 run 方法。训练结束后,只需调用model.trainable_weights
-
@Giuseppe Angora 之后如何提取权重值?我需要为权重计算 Fisher 信息矩阵,为此我需要以某种方式达到标量值。当我尝试使用索引时,我得到了像
<tf.Tensor 'strided_slice_13:0' shape=() dtype=float32>这样的结构 -
@SzymonMaszke 不,它推荐 sess.run(p) 方法,该方法目前对我不起作用(但似乎对海报工作正常。这意味着我实施错误)。
-
@kravchea 这个帖子的Tensorflow Support 回答怎么样?你在用
tensorflow2.x吗?请记住,tf1.x处于维护模式,如果您不是,请考虑切换。
标签: python tensorflow