【发布时间】:2017-01-14 14:25:42
【问题描述】:
在练习面向初学者的官方 tensorflow mnist 数据集教程时,我正在尝试将 mnist 数据更改为我自己从搜索引擎收集的图像。
strFilePaths,iLabels ,strSubFolderNames,iNumTotalDatasets = ScanForImage('Datasets')
tsFileNameQueue = tf.train.string_input_producer(strFilePaths)
tsReader = tf.WholeFileReader()
_,tsImage = tsReader.read(tsFileNameQueue)
tsImage = tf.image.decode_jpeg(tsImage, channels=3)
tsImage = tf.cast(tsImage,tf.float32)
tsLabels = tf.convert_to_tensor(iLabels, dtype=tf.float32)
tsImage = tf.reshape(tsImage, shape=[1,168*300*3])
matWeights = tf.Variable(tf.random_normal([168*300*3, 2]))
vBiases = tf.Variable(tf.zeros([2]))
vPredictions = tf.nn.softmax(tf.matmul(tsImage, matWeights) + vBiases)
fCrossEntropy = tf.reduce_mean(-tf.reduce_sum(tsLabels * tf.log(vPredictions), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(fCrossEntropy)
init = tf.global_variables_initializer()
with tf.Session() as sess :
sess.run(init)
for i in range (1000) :
tsTrainingSets = tf.train.batch([tsImage,tsLabels], batch_size=100)
sess.run(train_step)
if i % 20 == 0 :
correct_prediction = tf.equal(tf.argmax(vPredictions,1),tf.argmax(tsTrainingSets[1],1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
print(sess.run(accuracy))
strFilePaths 是一个标准的 python 列表,包含我所有的图像路径,iLabels 是一个代表标签的列表列表。在这种情况下,我只有 2 个班级。
这个程序运行没有错误输出,但 tensorflow 只是继续运行,没有给我任何输出。我已经阅读了 tensorflow 网站上的“阅读文件”会话大约一千遍,但我仍然不知道我是否做对了。
Q1:这段代码有什么问题? Q2:有没有完整的例子说明如何将jpeg文件读入tensorflow并对其执行一些训练任务?
【问题讨论】:
-
Image Retraining教程使用decode jpeg,即github.com/tensorflow/tensorflow/blob/…
标签: python input tensorflow jpeg