【Posted】: 2018-03-20 10:52:12
【Question】:
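I want to train the "Stanford chatbot" from https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot on my GPU, but it does not use the GPU, even though all the required libraries (cuDNN, CUDA, tensorflow-gpu, etc.) are installed. This is what I tried: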
def train():
    """ Train the bot """
    test_buckets, data_buckets, train_buckets_scale = _get_buckets()
    # in train mode, we need to create the backward path, so forward_only is False
    model = ChatBotModel(False, config.BATCH_SIZE)
    model.build_graph()
    saver = tf.train.Saver(var_list=tf.trainable_variables())
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True)) as sess:
        print('Start training')
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)
        iteration = model.global_step.eval()
        total_loss = 0
        while True:
            skip_step = _get_skip_step(iteration)
            bucket_id = _get_random_bucket(train_buckets_scale)
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(data_buckets[bucket_id],
                                                                           bucket_id,
                                                                           batch_size=config.BATCH_SIZE)
            start = time.time()
            _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, False)
            total_loss += step_loss
            iteration += 1
            if iteration % skip_step == 0:
                print('Iteration {}: loss {}, time {}'.format(iteration, total_loss/skip_step, time.time() - start))
                start = time.time()
                total_loss = 0
                saver.save(sess, os.path.join(config.CPT_PATH, 'chatbot'), global_step=model.global_step)
                if iteration % (10 * skip_step) == 0:
                    # Run evals on development set and print their loss
                    _eval_test_set(sess, model, test_buckets)
                    start = time.time()
                sys.stdout.flush()
But it always shows:
InvalidArgumentError (see above for traceback): Cannot assign a device to node 'save/Const': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
Colocation Debug Info:
Colocation group had the following types and devices:
Const: CPU
Identity: CPU
[[Node: save/Const = Constdtype=DT_STRING, value=Tensor, _device="/device:GPU:0"]]
Is there a configuration file for TensorFlow where I can specify that only the GPU should be used, or some other way to do this? (I tried with tf.device("/gpu:0"): and device_count={'GPU': 1}.)
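As a reading aid for the error message, here is a toy model in plain Python of what an explicit device request and soft placement mean. It is only a sketch and not TensorFlow's actual placer: the `SUPPORTED_DEVICES` table and the `place` function are hypothetical, and the table entries are assumptions chosen to mirror the colocation info in the error (a string `Const` with only a CPU kernel).

```python
# Toy model of device placement. NOT TensorFlow's real placer.
# The kernel table is a hypothetical stand-in that mirrors the error message:
# string Const/Identity ops (such as 'save/Const') only have CPU kernels here.
SUPPORTED_DEVICES = {
    "Const[DT_STRING]": {"CPU"},
    "Identity[DT_STRING]": {"CPU"},
    "MatMul[DT_FLOAT]": {"CPU", "GPU"},
}


def place(op, requested=None, allow_soft_placement=False):
    """Return the device an op would run on under this toy model."""
    supported = SUPPORTED_DEVICES[op]
    if requested is None:
        # No explicit request: prefer the GPU when a GPU kernel exists.
        return "GPU" if "GPU" in supported else "CPU"
    if requested in supported:
        return requested
    if allow_soft_placement:
        # Fall back to the CPU; every op in this table has a CPU kernel.
        return "CPU"
    raise ValueError(
        "Cannot assign a device to node {!r}: no supported kernel "
        "for {} devices is available".format(op, requested)
    )


if __name__ == "__main__":
    print(place("MatMul[DT_FLOAT]", requested="GPU"))
    print(place("Const[DT_STRING]", requested="GPU", allow_soft_placement=True))
```

In this model, pinning the string constant to the GPU without soft placement raises an error of the same shape as the one reported, while ops that do have a GPU kernel are unaffected by the fallback.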
【Comments】:
Tags: tensorflow neural-network python-3.5 lstm tensorflow-gpu