【发布时间】:2017-09-01 10:36:22
【问题描述】:
我想在一个 python 应用程序中使用 model.fit() 并行训练一些不同的模型。使用的模型没有必要的共同点,它们是在不同的时间在一个应用程序中启动的。
首先,我在单独的线程中启动一个没有问题的 model.fit(),然后是主线程。如果我现在想启动第二个 model.fit(),我会收到以下错误消息:
Exception in thread Thread-1:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'hidden_1/BiasAdd': Unknown input node 'hidden_1/MatMul'
它们都是通过相同的代码行从一个方法开始的:
start_learn(self:)
tf_session = K.get_session() # this creates a new session since one doesn't exist already.
tf_graph = tf.get_default_graph()
keras_learn_thread.Learn(learning_data, model, self.env_cont, tf_session, tf_graph)
learning_results.start()
调用的类/方法如下所示:
def run(self):
tf_session = self.tf_session # take that from __init__()
tf_graph = self.tf_graph # take that from __init__()
with tf_session.as_default():
with tf_graph.as_default():
self.learn(self.learning_data, self.model, self.env_cont)
# now my learn method where model.fit() is located is being started
我想我必须以某种方式为每个线程分配一个新的 tf_session 和一个新的 tf_graph。但我不太确定。我会为每一个简短的想法感到高兴,因为我现在坐这个太久了。
谢谢
【问题讨论】:
标签: python multithreading tensorflow keras