【Question title】: TensorFlow checkpoint does not save all variables
【Posted】: 2018-11-17 15:18:01
【Question】:

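I implemented an MLP in tensorflow-gpu 1.8.0 and used Hyperopt to find the best parameter configuration. A checkpoint is saved every time the loss function is further minimized. The checkpoint files are overwritten each time, so at the end of the process I only have these files: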

checkpoint

Model_1_checkpoint.ckpt.data-00000-of-00001

Model_1_checkpoint.ckpt.index

Model_1_checkpoint.meta
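For reference, a TF 1.x `tf.train.Saver` using the default V2 format writes an `.index` file, one or more `.data-XXXXX-of-YYYYY` shards and, by default, a `.meta` graph file, all sharing the prefix passed to `saver.save()`. It also updates a small text file named `checkpoint` in the same directory. The helper below is a hypothetical sketch (not a TensorFlow API) that lists the expected names for a given prefix, assuming a single data shard and the default `write_meta_graph=True`; the `models` directory is made up. With the prefix `Model_1_checkpoint.ckpt`, the meta graph would be named `Model_1_checkpoint.ckpt.meta`.

```python
import os


def expected_checkpoint_files(prefix, num_shards=1):
    """Hypothetical helper: file names a TF 1.x V2 Saver is expected to
    write for `prefix` (assumes write_meta_graph=True, the default)."""
    directory = os.path.dirname(prefix)
    files = [
        # Shared state file that records the most recent checkpoint prefix.
        os.path.join(directory, "checkpoint"),
        prefix + ".index",
        prefix + ".meta",
    ]
    # One data file per shard: <prefix>.data-00000-of-00001, ...
    for shard in range(num_shards):
        files.append("{}.data-{:05d}-of-{:05d}".format(prefix, shard, num_shards))
    return files


# "models" is an illustrative directory name, not taken from the question.
for name in expected_checkpoint_files(os.path.join("models", "Model_1_checkpoint.ckpt")):
    print(name)
```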

Below is the code that trains the MLP (it runs inside a function):

# Hyperparameters
n_step= np.round(parameters['step'],3)
n_hidden= int(parameters['number_neurons'])
n_bias= np.round(parameters['bias'],3)
n_batch= int(parameters['batch'])

# General variables
N_instances= xtrain_data_1_T60.shape[0]
N_input= xtrain_data_1_T60.shape[1]
N_classes= enc_ytrain_data_1_T60.shape[1]
N_epochs= 500
display_step= 100

# Reset graph
tf.reset_default_graph()

# Placeholders
X= tf.placeholder(name= "Logs", dtype= tf.float32, shape= [None, N_input])
y= tf.placeholder(name= "Facies", dtype= tf.float32, shape= [None, N_classes])

# MLP network architecture
input_layer= tf.layers.dense(X, units= N_input, activation= None, 
                             kernel_initializer= tf.keras.initializers.glorot_normal(1969),
                             bias_initializer= tf.keras.initializers.Zeros())

hidden_layer= tf.layers.dense(input_layer, units= n_hidden, activation= tf.nn.tanh, 
                              kernel_initializer= tf.keras.initializers.he_normal(1969),
                              bias_initializer= tf.keras.initializers.Constant(n_bias))

output_layer= tf.layers.dense(hidden_layer, units= N_classes, activation= tf.nn.softmax,
                              kernel_initializer= tf.keras.initializers.he_normal(1969),
                              bias_initializer= tf.keras.initializers.Zeros(), name= "mlp_output")

loss_op= tf.reduce_mean(tf.keras.backend.binary_crossentropy(y, output_layer))

optimizer= tf.train.GradientDescentOptimizer(learning_rate= n_step).minimize(loss_op)

# Initialize variables
init= tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()


with tf.Session() as sess:
    sess.run(init)

    # Training loop
    for epoch in range(0, N_epochs):
        avg_cost = 0.
        total_batch= int(N_instances/n_batch)
        start_idx= 0
        end_idx= n_batch

        for i in range(0, total_batch):
            batchx= xtrain_data_1_T60[start_idx:end_idx,:]
            batchy= enc_ytrain_data_1_T60[start_idx:end_idx,:]

            _, c= sess.run([optimizer, loss_op], feed_dict= {X: batchx, y: batchy})
            avg_cost += c/total_batch

            # Set next batch
            start_idx += n_batch
            end_idx += n_batch
            if (end_idx > N_instances):
                end_idx= N_instances

        if (epoch % display_step == 0):
            print("Epoch : %03d/%03d cost : %.4f\n"%(epoch, N_epochs, avg_cost))

    print("Optimization finished\n")

    prediction_1= sess.run(output_layer, feed_dict= {X: xvalidation_data_1_V40})
    prediction_1= prediction_1.argmax(axis= 1) + 1

    # Initialize a saver to save the current best model
    saver= tf.train.Saver(max_to_keep= 1)

    # Only check for prediction results with 3 lithofacies. Otherwise, I assign a dummy error and accuracy
    if len(np.unique(prediction_1)) == 3:
        error= 1. - metrics.recall_score(yvalidation_data_1_V40, prediction_1, average= 'micro')
        accuracy= metrics.accuracy_score(yvalidation_data_1_V40, prediction_1)

        global temp_error
        if (error < temp_error):
            temp_error= error
            saver.save(sess, '{}/{}'.format(checkpoint_path, checkpoint_name))
            print("Best model saved in file: ", '{}/{}'.format(checkpoint_path, checkpoint_name))
            print()
    else:
        error= 3
        accuracy= 0.00

    print("Error: {}".format(error))
    print("Accuracy: {:.2%}".format(accuracy))
    print("Predicted number of lithofacies: {}\n".format(len(np.unique(prediction_1))))

    sess.close()
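One detail of the training loop worth checking: `total_batch` is `int(N_instances/n_batch)`, so the loop only visits whole batches and the last `N_instances % n_batch` rows are never fed to the optimizer (the `end_idx` clamp only fires after the final batch has already been used). The sketch below reproduces the loop's index arithmetic in plain Python; `batch_bounds` is a hypothetical name, not part of the original code.

```python
def batch_bounds(n_instances, n_batch):
    """Mirror the question's slicing logic. Returns the (start, end) pairs
    actually used in one epoch and the number of rows never visited."""
    total_batch = int(n_instances / n_batch)  # floor division for positive ints
    bounds = []
    start_idx, end_idx = 0, n_batch
    for _ in range(total_batch):
        bounds.append((start_idx, end_idx))
        start_idx += n_batch
        end_idx += n_batch
        if end_idx > n_instances:
            end_idx = n_instances  # only affects a batch that is never used
    skipped = n_instances - (bounds[-1][1] if bounds else 0)
    return bounds, skipped


bounds, skipped = batch_bounds(10, 3)
print(bounds)   # [(0, 3), (3, 6), (6, 9)]
print(skipped)  # 1
```

Iterating with `range(0, N_instances, n_batch)` and slicing up to `min(start + n_batch, N_instances)` would also cover the last partial batch; whether that matters depends on the size of the training set.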

Then, in the same script, I restore the saved checkpoint to compute predictions:

tf.reset_default_graph()

# Restore the best model and predict again
with tf.Session() as sess:
    new_saver= tf.train.import_meta_graph(checkpoint_path + "/" + checkpoint_name + ".meta")
    new_saver.restore(sess, checkpoint_path)

    # Retrieve placeholder from restored graph
    X= best_model_1.get_tensor_by_name('Logs:0')

    # Retrieve output layer of MLP network to compute predictions
    pred= best_model_1.get_tensor_by_name('mlp_output/kernel:0')

    model_prob_density_1= sess.run(pred, feed_dict= {X: voting_data})

Unfortunately, the line `new_saver.restore(...)` raises the following error:

NotFoundError (see above for traceback): Key dense/bias not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
[[Node: save/RestoreV2/_9 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_14_save/RestoreV2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

I went back to the first code block and tried these changes:

a) saver= tf.train.Saver(max_to_keep= 1) --> saver= tf.train.Saver(tf.global_variables(), max_to_keep= 1)

b) saver= tf.train.Saver(max_to_keep= 1) --> saver= tf.train.Saver(tf.trainable_variables(), max_to_keep= 1)

However, I still get the same error message.
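It may help to know which keys this checkpoint is expected to contain. In TF 1.x, `tf.layers.dense` layers created without a `name` are auto-named `dense`, `dense_1`, `dense_2` and so on within a graph, and each layer stores its weights as `<layer>/kernel` and `<layer>/bias`. So the missing key `dense/bias` belongs to the unnamed input layer, and the six `DT_FLOAT` entries in the `RestoreV2` node match three layers with a kernel and a bias each. Because `GradientDescentOptimizer` creates no slot variables and `minimize()` was called without a `global_step`, the default `Saver`, `tf.global_variables()` and `tf.trainable_variables()` should all cover the same six variables, which would explain why changes a) and b) made no difference. The plain-Python sketch below only imitates the naming scheme; it is a hypothetical helper, not TensorFlow's implementation.

```python
def expected_variable_names(layer_names):
    """Imitate (not reproduce) TF 1.x naming for tf.layers.dense:
    unnamed layers become 'dense', 'dense_1', ... in creation order,
    explicit names are used as-is (assumed unique), and every layer
    owns a '<layer>/kernel' and a '<layer>/bias' variable."""
    auto_count = 0
    names = []
    for given in layer_names:
        if given is None:
            layer = "dense" if auto_count == 0 else "dense_{}".format(auto_count)
            auto_count += 1
        else:
            layer = given
        names += [layer + "/kernel", layer + "/bias"]
    return names


# The question's graph: unnamed input layer, unnamed hidden layer, named output layer.
print(expected_variable_names([None, None, "mlp_output"]))
# ['dense/kernel', 'dense/bias', 'dense_1/kernel', 'dense_1/bias',
#  'mlp_output/kernel', 'mlp_output/bias']
```

To compare this against what is actually on disk, `tf.train.list_variables()` (or `tf.train.NewCheckpointReader`) lists the keys stored under a checkpoint prefix.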

Any suggestions?

Many thanks, Ivan

【Question comments】:

  • To avoid errors caused by a wrong checkpoint path, you should always build path strings with the os module, specifically os.path.join(...).
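Following that advice, the prefix passed to `saver.save()` and the one passed to `saver.restore()` should be the same string: `restore()` expects the checkpoint prefix (`<dir>/<name>`), not the directory on its own. A minimal sketch, assuming `checkpoint_path` is the directory and `checkpoint_name` the file prefix as in the question; `checkpoint_paths` is a hypothetical helper and the example values are made up. As an alternative, `tf.train.latest_checkpoint(checkpoint_path)` returns the most recent prefix recorded in the `checkpoint` file.

```python
import os


def checkpoint_paths(checkpoint_path, checkpoint_name):
    """Build the prefix used by saver.save()/saver.restore() and the
    .meta file passed to tf.train.import_meta_graph()."""
    prefix = os.path.join(checkpoint_path, checkpoint_name)
    return prefix, prefix + ".meta"


# Hypothetical values for illustration only.
prefix, meta_file = checkpoint_paths("models", "Model_1_checkpoint.ckpt")
# Intended use (TF 1.x):
#   saver.save(sess, prefix)
#   new_saver = tf.train.import_meta_graph(meta_file)
#   new_saver.restore(sess, prefix)   # the prefix, not the directory
print(prefix)
print(meta_file)
```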

Tags: variables tensorflow restore checkpoint


【Solution 1】:

Sorry for any confusion: the code in the second block was not the right version.

Here is the correct version:

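import os
import tensorflow as tf

tf.reset_default_graph()

with tf.Session() as sess:

    # Rebuild the graph from the .meta file, then load the saved weights.
    # restore() expects the same checkpoint prefix that was given to saver.save(), not the directory.
    ckpt_prefix= os.path.join(checkpoint_path, checkpoint_name)
    new_saver= tf.train.import_meta_graph(ckpt_prefix + ".meta")
    new_saver.restore(sess, ckpt_prefix)
    graph= tf.get_default_graph()

    # Retrieve placeholder from restored graph
    X= graph.get_tensor_by_name('Logs:0')

    # Retrieve the softmax output of the MLP to compute predictions
    # ('mlp_output/kernel:0' would only return the layer's weight matrix)
    pred= graph.get_tensor_by_name('mlp_output/Softmax:0')

    model_prob_density_1= sess.run(pred, feed_dict= {X: voting_data})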
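A tensor named `mlp_output/kernel:0` is the output layer's weight matrix, not its predictions: evaluating it returns the same `(n_hidden, N_classes)` array whatever is fed to `X`. Predictions come from the layer's activation, which for `tf.layers.dense(..., activation=tf.nn.softmax, name="mlp_output")` is typically named `mlp_output/Softmax:0` in TF 1.x (worth confirming with `graph.get_operations()`). The plain-Python sketch below uses made-up weights to contrast the fixed kernel with the per-input softmax output of one dense layer.

```python
import math


def dense_softmax(x, kernel, bias):
    """Forward pass of one dense layer with softmax activation,
    softmax(x @ kernel + bias), for a single input row x."""
    logits = [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up weights: 2 hidden units -> 3 classes.
kernel = [[0.5, -0.2, 0.1],
          [0.3, 0.8, -0.5]]
bias = [0.0, 0.0, 0.0]
hidden = [1.0, 2.0]

probs = dense_softmax(hidden, kernel, bias)
print(len(kernel), len(kernel[0]))  # kernel shape 2 3, identical for every input
print(probs)                        # one probability per class, summing to 1
predicted_class = probs.index(max(probs)) + 1  # +1 mirrors the question's 1-based labels
print(predicted_class)
```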

【Comments】:

  • You can use the edit function to add the new information to your original question.