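【Posted】: 2018-02-13 15:58:53
【Question】:
We plan to implement distributed training in TensorFlow, using Distributed TensorFlow (https://www.tensorflow.org/deploy/distributed). We were able to get distributed training working with asynchronous between-graph replication. The code snippet is below.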
.....
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# Create and start a server for the local task.
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)
img_width, img_height = 124, 124
if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    ####### Assigns ops to the local worker by default. #######
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        ######## set Keras learning phase to train #######
        K.set_learning_phase(1)
        # do not initialize variables on the fly
        K.manual_variable_initialization(True)
        if K.image_data_format() == 'channels_first':
            input_shape = (3, img_width, img_height)
        else:
            input_shape = (img_width, img_height, 3)
        X = tf.placeholder(tf.float32, shape=[None, img_width, img_height, 3], name="X")
        Y = tf.placeholder(tf.float32, shape=[None, n_classes], name="Y")
        print("Building keras model")
        ....
        ....
        ####### Defining our total loss #######
        ###### Defining our TF Optimizer and passing hyperparameters ######
        .......
        ...........
        ...............
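As a side note on the cluster definition above, the sketch below shows in plain Python (no TensorFlow required) what the dictionary passed to tf.train.ClusterSpec looks like once the comma-separated flags are split, and how a job name and task index map to one address. The helper names build_cluster_dict and task_address, and the example host names, are hypothetical and are not part of the original script.

```python
def build_cluster_dict(ps_hosts, worker_hosts):
    """Split comma-separated host flags into the dict shape ClusterSpec expects."""
    def split_hosts(value):
        # Ignore empty entries such as a trailing comma.
        return [host.strip() for host in value.split(",") if host.strip()]

    cluster = {"ps": split_hosts(ps_hosts), "worker": split_hosts(worker_hosts)}
    for job, hosts in cluster.items():
        if not hosts:
            raise ValueError("job %r has no hosts" % job)
    return cluster


def task_address(cluster, job_name, task_index):
    """Return the host:port that a (job_name, task_index) pair refers to."""
    if job_name not in cluster:
        raise ValueError("unknown job %r" % job_name)
    if not 0 <= task_index < len(cluster[job_name]):
        raise ValueError("task_index %d out of range for job %r" % (task_index, job_name))
    return cluster[job_name][task_index]


# Hypothetical host names, for illustration only.
example_cluster = build_cluster_dict(
    "ps0.example.com:2222",
    "worker0.example.com:2222,worker1.example.com:2222")
chief_address = task_address(example_cluster, "worker", 0)
```

In the script above, the worker with task_index 0 is also the one treated as chief by the Supervisor.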
We defined our training Supervisor as shown below.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                         logdir=logdir,
                         init_op=init_op,
                         saver=saver,
                         summary_op=summary_op,
                         global_step=global_step)
We then created the session from the Supervisor with the following snippet.
with sv.prepare_or_wait_for_session(server.target) as sess:
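We then fed the batches through during training. Up to this point everything works. However, when we try to save/export the model for TensorFlow Serving, it does not produce the correct set of checkpoint files for us to serve in production. When hosting the checkpoint files through tensorflow_model_server, we get the error below.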
Loading servable: {name: default version: 2} failed: Invalid argument:
Cannot assign a device for operation 'init_all_tables': Operation was
explicitly assigned to /job:worker/task:0 but available devices are [
/job:localhost/replica:0/task:0/cpu:0 ]. Make sure the device specification
refers to a valid device.[[Node: init_all_tables =
NoOp[_device="/job:worker/task:0"]()]]
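The message says that an op in the exported graph is still pinned to /job:worker/task:0, while the serving process only offers /job:localhost/replica:0/task:0/cpu:0. The sketch below is a simplified pure-Python model of that check, not TensorFlow's actual placer; parse_device and satisfies are hypothetical names. It assumes a device counts as a match when every field named in the requested spec has the same value on the available device.

```python
def parse_device(spec):
    """Parse a device string such as '/job:worker/task:0' into a dict of fields."""
    fields = {}
    for part in spec.strip("/").split("/"):
        if not part:
            continue  # an empty spec means "no constraint"
        key, _, value = part.partition(":")
        key = key.lower()
        if key == "device":
            # Long form, e.g. 'device:CPU:0'
            dev_type, _, dev_index = value.partition(":")
            fields["device_type"] = dev_type.lower()
            fields["device_index"] = dev_index
        elif key in ("cpu", "gpu"):
            # Short form, e.g. 'cpu:0'
            fields["device_type"] = key
            fields["device_index"] = value
        else:
            fields[key] = value
    return fields


def satisfies(requested, available):
    """True if every field named in `requested` matches the available device."""
    wanted = parse_device(requested)
    offered = parse_device(available)
    return all(offered.get(key) == value for key, value in wanted.items())


serving_device = "/job:localhost/replica:0/task:0/cpu:0"
pinned = satisfies("/job:worker/task:0", serving_device)  # job differs, no match
cleared = satisfies("", serving_device)                   # no constraint, matches
```

Under this model, an op with an empty device string can be placed anywhere, which is the effect that clearing device fields at export time is meant to achieve.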
Note that we also tried the following ways of saving the trained graph.
i) SavedModelBuilder
builder = saved_model_builder.SavedModelBuilder(export_path)
ii) Model exporter
export_path = "/saved_graph/"
model_exporter.export(export_path, sess)
iii) tf.train.Saver
tf.train.Saver
None of these approaches succeeded.
We could not find any article that shows a complete example or a detailed explanation. We went through the following references.
https://github.com/tensorflow/tensorflow/issues/5439
https://github.com/tensorflow/tensorflow/issues/5110
Running distributed Tensorflow with InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float
Any suggestions or references would be a great help.
Thanks.
----------------------------------------------- ----------------------
Following a suggestion, we tried using "clear_devices=True" when exporting the model, but that did not help. Below is the code snippet we used.
for epoch in range(training_epochs):
    epoch_num = 0
    batch_count = int(num_img / batch_size)
    count = 0
    for i in range(batch_count):
        epoch_num = 0
        # This will create batches out of our training dataset, which will
        # be passed to the feed_dict
        batch_x, batch_y = next_batch(
            batch_size, epoch_num, train_data, train_labels, num_img)
        # perform the operations we defined earlier on batch
        _, cost, step = sess.run([train_op, cross_entropy, global_step],
                                 feed_dict={X: batch_x, Y: batch_y})
sess.run(tf.global_variables_initializer())
builder = tf.saved_model.builder.SavedModelBuilder(path)
builder.add_meta_graph_and_variables(
    sess,
    [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        "magic_model":
            tf.saved_model.signature_def_utils.predict_signature_def(
                inputs={"image": X},
                outputs={"prediction": preds})
    }, clear_devices=True)
builder.save()
sv.stop()
print("Done!!")
When we use clear_devices=True, we run into the following error.
Error: Traceback (most recent call last):
  File "insulator_classifier.py", line 370, in <module>
    tf.app.run()
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "insulator_classifier.py", line 283, in main
    }, clear_devices=False)
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/saved_model/builder_impl.py", line 364, in add_meta_graph_and_variables
    allow_empty=True)
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1140, in __init__
    self.build()
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1172, in build
    filename=self._filename)
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 677, in build
    filename_tensor = constant_op.constant(filename or "model")
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 106, in constant
    attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2582, in create_op
    self._check_not_finalized()
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2290, in _check_not_finalized
    raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.
What are we missing here?
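For readers following the traceback: it ends in _check_not_finalized, reached while add_meta_graph_and_variables builds a new Saver that tries to add a constant op to the graph. As far as we understand, tf.train.Supervisor finalizes the graph once it is set up, so any op created afterwards is rejected. The sketch below is a toy pure-Python model of that guard, not TensorFlow code; ToyGraph and build_saver are invented names used only to illustrate the ordering problem.

```python
class ToyGraph:
    """Toy stand-in for a graph that can be frozen against modification."""

    def __init__(self):
        self.ops = []
        self.finalized = False

    def create_op(self, name):
        # Mirrors the guard seen at the bottom of the traceback.
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.ops.append(name)
        return name

    def finalize(self):
        self.finalized = True


def build_saver(graph):
    """A saver needs ops of its own (here, a filename constant)."""
    return graph.create_op("save/filename")


graph = ToyGraph()
graph.create_op("X")
early_saver = build_saver(graph)   # created before finalize(): accepted
graph.finalize()                   # models what the Supervisor does

try:
    build_saver(graph)             # created after finalize(): rejected
    late_error = None
except RuntimeError as exc:
    late_error = str(exc)
```

In this model, anything that needs new ops has to be built before the graph is frozen, or built in a separate graph.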
----------------------------------------------- ----------------------
Further update:
We can see it working with suggestion 2) from @Tianjin Gu. See the code snippet below.
X = tf.placeholder(tf.float32, shape=[None, img_width, img_height, 3], name="X")
Y = tf.placeholder(tf.float32, shape=[None, n_classes], name="Y")
....
....
model_exporter = exporter.Exporter(saver)
model_exporter.init(
    tf.get_default_graph().as_graph_def(),
    named_graph_signatures={
        'inputs': exporter.generic_signature({'input': X}),
        'outputs': exporter.generic_signature({'output': Y})},
    clear_devices=True)
export_path = "/export_path"
When we export, we see this warning:
WARNING:tensorflow:From test_classifier.py:283: Exporter.export (from tensorflow.contrib.session_bundle.exporter) is deprecated and will be removed after 2017-06-30.
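So ideally we should be using "tf.saved_model.builder.SavedModelBuilder", but for some reason that does not work.
Any other suggestions?
Thanks.
【Comments】:
- To clarify, what is the remaining error you get when using SavedModelBuilder?
Tags: tensorflow model export distributed tensorflow-serving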