【发布时间】:2020-02-26 18:49:49
【问题描述】:
我正在使用 ScipyOptimizerInterface 来训练 tensorflow 模型。 (tensorflow 1.13.1)
在训练过程中,如果loss 值低于阈值,我希望在超过阈值之前停止训练过程并保存模型。
以下是我尝试过的脚本。想法是引发异常以退出optimizer.minimize,然后使用tf.train.Saver 保存模型。
但是,这不起作用:比较初始 loss 值与用保存的模型计算出的 loss 值可以看到,这两个值相同,说明保存的是初始的随机模型,而不是想要的模型。
从 @Patol75 的回答中,我了解到最佳模型之所以没有被保存,是因为训练会话被中断时,已更新的 tf.Variables 也随之丢失了。
如何保存所需的模型?
import numpy as np
import tensorflow as tf
from tensorflow.contrib.opt import ScipyOptimizerInterface
class test(Exception):
    """Raised inside the step callback to abort the optimizer early."""
def construct_graph():
    """Build a new graph with a one-feature linear model and its MSE loss.

    Returns:
        A tuple ``(graph, x, y, loss)``: the freshly created ``tf.Graph``,
        the input placeholder, the target placeholder, and the mean of the
        squared difference between ``y`` and ``matmul(x, w) + b``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Inputs and targets are column vectors with a free batch dimension.
        x = tf.placeholder('float', shape=[None, 1])
        w = tf.get_variable(
            'w_0',
            shape=[1, 1],
            initializer=tf.contrib.layers.xavier_initializer(),
        )
        b = tf.get_variable(
            'b_0',
            shape=[1],
            initializer=tf.contrib.layers.xavier_initializer(),
        )
        y_out = tf.matmul(x, w) + b
        y = tf.placeholder('float', shape=[None, 1])
        loss = tf.reduce_mean(tf.square(y - y_out))
    return graph, x, y, loss
# Example datasets: jittered inputs starting at 1 for training, evenly spaced
# inputs on [6, 11] for validation; the targets are sin(x) in both cases.
jitter = 0.1 * np.random.random(100)
x_train = (np.linspace(1, 6, 100) + jitter).reshape(100, 1)
y_train = np.sin(x_train)
x_val = np.linspace(6, 11, 100).reshape(100, 1)
y_val = np.sin(x_val)

# Clear the default graph, then build the model in a graph of its own.
tf.reset_default_graph()
graph, x, y, loss = construct_graph()

# Feed dicts that bind the two placeholders to each dataset.
feeddict_train = {x: x_train, y: y_train}
feeddict_val = {x: x_val, y: y_val}
with graph.as_default():
    def step_callbackfun(packed_vars):
        """Print train/validation loss for one step; abort on the fifth.

        Args:
            packed_vars: the value the optimizer hands to ``step_callback``
                (presumably the current flattened variables -- confirm). It
                is forwarded unchanged to both evaluation functions.

        Raises:
            test: once five steps have been reported, to leave ``minimize``.
        """
        global iteration
        train_part = valfunc_train(packed_vars)
        val_part = valfunc_val(packed_vars)
        print('%10.5f %10.5f' % (*train_part, *val_part))
        iteration += 1
        if iteration == 5:
            raise test()

    # Step counter shared with the callback through ``global``.
    iteration = 0

    # The context manager closes the training session when the block ends.
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        optimizer = ScipyOptimizerInterface(loss, method='l-BFGS-b')

        # Loss evaluators for a candidate parameter value, one per dataset.
        # NOTE: _make_eval_func is a private method of the optimizer.
        valfunc_train = optimizer._make_eval_func(
            tensors=loss, session=sess, feed_dict=feeddict_train, fetches=[])
        valfunc_val = optimizer._make_eval_func(
            tensors=loss, session=sess, feed_dict=feeddict_val, fetches=[])

        print('The initial loss is %f' % sess.run(loss, feeddict_train))
        try:
            optimizer.minimize(
                sess, feeddict_train, step_callback=step_callbackfun)
        except test:
            # NOTE(review): as the question reports, the checkpoint written
            # here still holds the initial weights. Presumably the optimizer
            # only assigns its result back to the tf.Variables after
            # minimize() returns normally, so leaving via an exception skips
            # that write-back -- confirm against the TF 1.13 source. A likely
            # remedy is to keep the last value seen by the callback and
            # assign it to the variables before calling saver.save().
            saver.save(sess, 'model/model.ckpt')

            # Reload the checkpoint into an independent graph and session to
            # check which weights were actually stored.
            graph2, x2, y2, loss2 = construct_graph()
            with tf.Session(graph=graph2) as sess2:
                feeddict_two = {x2: x_train, y2: y_train}
                sess2.run(tf.global_variables_initializer())
                saver2 = tf.train.Saver()
                saver2.restore(sess2, 'model/model.ckpt')
                loss_val2 = sess2.run(loss2, feeddict_two)
                print('Outside', loss_val2)
【问题讨论】:
标签: python tensorflow scipy scipy-optimize