【发布时间】:2017-05-24 07:15:55
【问题描述】:
我正在尝试使用 Tensorflow 编写自己的 MNIST 数字分类器,但我遇到了 tf.train.shuffle_batch 函数的奇怪行为。
当我尝试从不同文件加载图像和标签时出现问题,shuffle batch 似乎会自行对标签和图像进行混洗,因此会产生错误的标记数据。数据取自here
这是 shuffle_batch 函数的定义行为吗? 当数据和标签是不同的文件时,您建议如何处理这种情况?
这是我的代码
DATA = 'train-images.idx3-ubyte'
LABELS = 'train-labels.idx1-ubyte'
data_queue = tf.train.string_input_producer([DATA,])
label_queue = tf.train.string_input_producer([LABELS,])
NUM_EPOCHS = 2
BATCH_SIZE = 10
reader_data = tf.FixedLengthRecordReader(record_bytes=28*28, header_bytes = 16)
reader_labels = tf.FixedLengthRecordReader(record_bytes=1, header_bytes = 8)
(_,data_rec) = reader_data.read(data_queue)
(_,label_rec) = reader_labels.read(label_queue)
image = tf.decode_raw(data_rec, tf.uint8)
image = tf.reshape(image, [28, 28, 1])
label = tf.decode_raw(label_rec, tf.uint8)
label = tf.reshape(label, [1])
image_batch, label_batch = tf.train.shuffle_batch([image, label],
batch_size=BATCH_SIZE,
capacity=100,
min_after_dequeue = 30)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
image = image_batch[1]
im = image.eval()
print("im_batch shape :" + str(image_batch.get_shape().as_list()))
print("label shape :" + str(label_batch.get_shape().as_list()))
print("label is :" + str(label_batch[1].eval()))
# print("output is :" + str(conv1.eval()))
plt.imshow(np.reshape(im, [-1, 28]), cmap='gray')
plt.show()
coord.request_stop()
coord.join(threads)
【问题讨论】:
标签: python tensorflow deep-learning