不确定这是否正是您要寻找的,但我正在两个单独的循环中使用相同的代码进行训练和验证。我的代码从 .CSV 文件中读取数字和字符串数据,而不是图像。我正在读取两个单独的 CSV 文件,一个用于培训,一个用于验证。我相信您可以将其概括为从两个“组”文件中读取,而不仅仅是单个文件,因为代码就在那里。
这里是代码 sn-ps 以防万一。请注意,此代码首先将所有内容读取为字符串,然后将必要的单元格转换为浮点数,这只是我自己的要求。如果您的数据是纯数字的,您应该将默认值设置为浮点数,一切都应该更容易。此外,其中有几行将权重和偏差放入 CSV 文件并将它们序列化到 TF 检查点文件中,具体取决于您喜欢的方式。
#first define the defaults:
rDefaults = [['a'] for row in range((TD+TS+TL))]
# this function reads line-by-line from CSV and separates cells into chunks:
def read_from_csv(filename_queue):
reader = tf.TextLineReader(skip_header_lines=False)
_, csv_row = reader.read(filename_queue)
data = tf.decode_csv(csv_row, record_defaults=rDefaults)
dateLbl = tf.slice(data, [0], [TD])
features = tf.string_to_number(tf.slice(data, [TD], [TS]), tf.float32)
label = tf.string_to_number(tf.slice(data, [TD+TS], [TL]), tf.float32)
return dateLbl, features, label
#this function loads the above lines and spits them out as batches of N:
def input_pipeline(fName, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
[fName],
num_epochs=num_epochs,
shuffle=True)
dateLbl, features, label = read_from_csv(filename_queue)
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size # max of how much to load into memory
dateLbl_batch, feature_batch, label_batch = tf.train.shuffle_batch(
[dateLbl, features, label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
return dateLbl_batch, feature_batch, label_batch
# These are the TRAINING features, labels, and meta-data to be loaded from the train file:
dateLbl, features, labels = input_pipeline(fileNameTrain, batch_size, try_epochs)
# These are the TESTING features, labels, and meta-data to be loaded from the test file:
dateLblTest, featuresTest, labelsTest = input_pipeline(fileNameTest, batch_size, 1) # 1 epoch here regardless of training
# then you define the model, start the session, blah blah
# fire up the queue:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
#This is the TRAINING loop:
try:
while not coord.should_stop():
dateLbl_batch, feature_batch, label_batch = sess.run([dateLbl, features, labels])
_, acc, summary = sess.run([train_step, accuracyTrain, merged_summary_op], feed_dict={x: feature_batch, y_: label_batch,
keep_prob: dropout,
learning_rate: lRate})
except tf.errors.OutOfRangeError: # (so done reading the file(s))
# by the way, this dumps weights and biases into a CSV file, since you asked for that
np.savetxt(fPath + fIndex + '_weights.csv', sess.run(W),
# and this serializes weight and biases into the TF-formatted protobuf:
# tf.train.Saver({'varW': W, 'varB': b}).save(sess, fileNameCheck)
finally:
coord.request_stop()
# now re-start the runners for the testing file:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
# so now this line reads features, labels, and meta-data, but this time from the training file:
dateLbl_batch, feature_batch, label_batch = sess.run([dateLblTest, featuresTest, labelsTest])
guessY = tf.argmax(y, 1).eval({x: feature_batch, keep_prob: 1})
trueY = tf.argmax(label_batch, 1).eval()
accuracy = round(tf.reduce_mean(tf.cast(tf.equal(guessY, trueY), tf.float32)).eval(), 2)
except tf.errors.OutOfRangeError:
acCumTest /= i
finally:
coord.request_stop()
coord.join(threads)
这可能与您尝试执行的操作不同,因为它首先完成训练循环,然后重新启动测试循环的队列。如果您想返回第四个,不确定如何执行此操作,但您可以尝试通过交替传递相关文件名(或列表)来尝试使用上面定义的两个函数。
我也不确定训练后重新开始排队是否是最好的方法,但它对我有用。希望看到一个更好的例子,因为大多数 TF 例子都使用一些内置的 MNIST 数据集的包装器来一次性完成训练......