【问题标题】:Reading the dataset as batches to train批量读取数据集进行训练
【发布时间】:2020-01-14 15:17:44
【问题描述】:

我正在尝试读取 cifar10 数据集并将其用于训练模型,因此我尝试读取批次并运行如下会话:

 # Optimizer
    opt = tf.train.AdamOptimizer(0.0001)
    global_step = tf.get_variable('global_step', initializer=tf.constant(0), trainable=False)
    train_op = opt.apply_gradients(zip(grads, var_list), global_step=global_step)

    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

    image_batch, label_batch = tf.train.batch([x_train, y_train], batch_size=batch_size)
    #image_batch_uint8 = tf.cast(image_batch, tf.uint8)

    # Train
    with tf.Session() as sess:

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        sess.run(tf.global_variables_initializer())
        for i in range(10000000):
            _loss_value, _reward_value, _ = sess.run([loss, reward, train_op], feed_dict={
                images_ph: image_batch,
                labels_ph: label_batch
            })
            if i % 100 == 0:
                print('iter: ', i, '\tloss: ', _loss_value, '\treward: ', _reward_value)

但是我得到了这个错误:

 File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1086, in _run
    'feed with key ' + str(feed) + '.')
The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.For reference, the tensor object was Tensor("batch:0", shape=(32, 50000, 32, 32, 3), dtype=uint8) which was passed to the feed with key Tensor("Placeholder:0", shape=(?, 1024), dtype=float32).

我做错了什么?如何确保将所有数据集作为 epoch 提供,是否有更直接的方法来提供数据集??

【问题讨论】:

  • 为什么不是基本的方式,for循环遍历epochs和train?

标签: python tensorflow deep-learning neural-network conv-neural-network


【解决方案1】:

错误是因为变量image_batchlabel_batch是张量。提要字典的语法是{tensor1:value1,tensor2:value2.....}。所以你需要输入 numpy 数组来代替value1,value2..

所以你只需要做一个value1,value2 = sess.run([image_batch,label_batch])

总体变化如下:

.
.
for i in range(10000000):

    try:

         raw_images, raw_labels = sess.run([image_batch, label_batch])
        _loss_value, _reward_value, _ = sess.run([loss, reward, train_op], feed_dict={image_batch: raw_images, label_batch: raw_labels})

    except tf.errors.OutOfRangeError:
         print("Breaking...")
         break
         .
         .
    if i % 100 == 0:
         print('iter: ', i, '\tloss: ', _loss_value, '\treward: ', _reward_value)

我认为使用 tf.train.Coordinator() 而不是我写的 try..except 块,您也可以使用以下块(在他们的网站上):

try:
  while not coord.should_stop():
    ...do some work...
except Exception as e:
  coord.request_stop(e)

【讨论】:

    猜你喜欢
    • 2014-05-18
    • 1970-01-01
    • 2018-05-06
    • 2019-12-29
    • 2016-09-27
    • 1970-01-01
    • 1970-01-01
    • 2018-12-27
    • 1970-01-01
    相关资源
    最近更新 更多