【问题标题】:Getting values of trainable parameters in tensorflow在张量流中获取可训练参数的值
【发布时间】:2020-04-20 10:50:01
【问题描述】:

我正在尝试从模型中提取所有可训练的权重。 在 pytorch 中,类似的事情可以通过一行 p.grad.data for p in model.parameters() if p.requires_grad 来完成,但是我正在努力在 TF 中想出一个简单的解决方案。

我目前的尝试是这样的:

sess = tf.Session()

... #model initialization and training here

p = model.trainable_weights
p_vals = sess.run(p)

然而,最后一行会产生错误:

  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable conv1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/conv1/bias)
     [[{{node conv1/bias/Read/ReadVariableOp}}]]

我在这里做错了什么?我假设会话/图表没有正确链接到模型? 或者它确实是一个初始化问题(但模型能够成功训练)?

【问题讨论】:

  • 您应该可以看到模型的权重,无需调用 run 方法。训练结束后,只需调用model.trainable_weights
  • @Giuseppe Angora 之后如何提取权重值?我需要为权重计算 Fisher 信息矩阵,为此我需要以某种方式达到标量值。当我尝试使用索引时,我得到了像 <tf.Tensor 'strided_slice_13:0' shape=() dtype=float32> 这样的结构
  • @SzymonMaszke 不,它推荐 sess.run(p) 方法,该方法目前对我不起作用(但似乎对海报工作正常。这意味着我实施错误)。
  • @kravchea 这个帖子的Tensorflow Support 回答怎么样?你在用tensorflow2.x吗?请记住,tf1.x 处于维护模式,如果您不是,请考虑切换。

标签: python tensorflow


【解决方案1】:

使用自定义回调 Fn 会更轻松,然后使用自定义操作进行交易!

class custom_callback(tf.keras.callbacks.Callback): 
tf.summary.create_file_writer(val_dir)

def _val_writer(self):
    if 'val' not in self._writers:
        self._writers['val'] = tf.summary.create_file_writer(val_dir)
    return self._writers['val']

def on_epoch_end(self, epoch, logs={}):
    print('weights: ' + str(self.model.get_weights()))
    
    if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'):
        with tf.summary.record_if(True): # self._val_writer.as_default():
            step = ''
            for name, value in logs.items():
                tf.summary.scalar(
                'evaluation_' + name + '_vs_iterations',
                value,
                step=self.model.optimizer.iterations.read_value(),
                )
            print('step :' + str(self.model.optimizer.iterations.read_value()))

    if(logs['accuracy'] == None) : pass
    else:
        if(logs['accuracy']> 0.90):
            self.model.stop_training = True

custom_callback = custom_callback()

history = model_highscores.fit(batched_features, epochs=99 ,validation_data=(dataset.shuffle(len(list_image))), callbacks=[custom_callback]) 

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2021-06-28
    • 2016-11-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-12-16
    相关资源
    最近更新 更多