【问题标题】:tensorflow dataset shuffle examples instead of batchestensorflow 数据集 shuffle 示例而不是批处理
【发布时间】:2018-08-24 22:18:46
【问题描述】:

如何以批处理模式获取 TensorFlow 数据集以对所有样本进行随机播放?它只是洗牌批次。

下面是一个程序,它制作了一个包含 1000 个项目的数据集,并以 5 个批次经历了 10 个 epoch。我打开了shuffle()。我可以看到 tensorflow 将数据集分为 200 批,每批 5 个示例,并且洗牌跨越这些批。我希望每个新批次都是原始 1000 个样本的随机样本,而不是 200 个原始批次的样本。

也就是这个程序:

import numpy as np
import tensorflow as tf
import random


def rec2tfrec_example(rec):
    def _int64_feat(value):
        arr_value = np.empty([1], dtype=np.int64)
        arr_value[0] = value
        return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_value))

    feat = {
        'uid': _int64_feat(rec['uid']),
    }

    return tf.train.Example(features=tf.train.Features(feature=feat)).SerializeToString()


def parse_example(tfrec_serialized_string):
    feat = {
        'uid': tf.FixedLenFeature([], tf.int64),
    }
    return tf.parse_example(tfrec_serialized_string, feat)


def write_tfrecs_to_file(fname, recs):
        recwriter = tf.python_io.TFRecordWriter(fname)
        for rec in recs:
            recwriter.write(bytes(rec))
        recwriter.close()


def check_shuffle(sess, tfrec_output_filename, data, N, batch_size):
    epochs = 10
    dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
                     .batch(batch_size) \
                     .repeat(epochs) \
                     .shuffle(2*N) \
                     .map(parse_example, num_parallel_calls=2)
    tf_iter = dataset.make_initializable_iterator()
    get_next = tf_iter.get_next()

    sess.run(tf_iter.initializer)
    num_batches = N//batch_size
    for epoch in range(epochs ):
        for batch in range(N//batch_size):
            tfres = sess.run(get_next)
            print("epoch=%4d batch=%d uid=%s" % (epoch, batch, tfres['uid']))


def main(N=1000, batch_size=5, tfrec_output_filename='tfrec_testing.tfrecords'):
    tf.reset_default_graph()
    data = [{'uid': uid } for uid in range(N)]
    tfrec_strings = [rec2tfrec_example(rec) for rec in data]
    write_tfrecs_to_file(tfrec_output_filename, tfrec_strings)
    with tf.Session() as sess:
        check_shuffle(sess, tfrec_output_filename, data, N, batch_size)

if __name__ == '__main__':
    main()

产生如下输出:

epoch=   9 batch=186 uid=[685 686 687 688 689]
epoch=   9 batch=187 uid=[235 236 237 238 239]
epoch=   9 batch=188 uid=[520 521 522 523 524]
epoch=   9 batch=189 uid=[135 136 137 138 139]
epoch=   9 batch=190 uid=[95 96 97 98 99]
epoch=   9 batch=191 uid=[290 291 292 293 294]
epoch=   9 batch=192 uid=[230 231 232 233 234]
epoch=   9 batch=193 uid=[215 216 217 218 219]

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    啊,batch 和 shuffle 的顺序很重要,如果我像这样设置数据集

    dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
                     .shuffle(2*N) \
                     .batch(batch_size) \
                     .repeat(epochs) \
                     .map(parse_example, num_parallel_calls=2)
    

    在批处理前使用随机播放,然后它就可以工作了。

    【讨论】:

      猜你喜欢
      • 2018-10-30
      • 2018-07-21
      • 1970-01-01
      • 1970-01-01
      • 2017-10-22
      • 2018-08-17
      • 2020-10-29
      • 2018-09-03
      • 1970-01-01
      相关资源
      最近更新 更多