【问题标题】:how to read batches in one hdf5 data file for training?如何读取一个 hdf5 数据文件中的批次进行训练?
【发布时间】:2016-07-06 13:52:03
【问题描述】:

我有一个大小为(21760, 1, 33, 33) 的 hdf5 训练数据集。 21760 是训练样本的总数。我想使用大小为128 的小批量训练数据来训练网络。

我想问:

如何每次用tensorflow从整个数据集中提供128 mini-batch 训练数据?

【问题讨论】:

    标签: python tensorflow deep-learning


    【解决方案1】:

    如果你的数据集太大以至于无法像keveman建议的那样导入内存,你可以直接使用h5py对象:

    import h5py
    import tensorflow as tf
    
    data = h5py.File('myfile.h5py', 'r')
    data_size = data['data_set'].shape[0]
    batch_size = 128
    sess = tf.Session()
    train_op = # tf.something_useful()
    input = # tf.placeholder or something
    for i in range(0, data_size, batch_size):
        current_data = data['data_set'][position:position+batch_size]
        sess.run(train_op, feed_dict={input: current_data})
    

    如果您愿意,还可以运行大量迭代并随机选择一个批次:

    import random
    for i in range(iterations):
        pos = random.randint(0, int(data_size/batch_size)-1) * batch_size
        current_data = data['data_set'][pos:pos+batch_size]
        sess.run(train_op, feed_dict={inputs=current_data})
    

    或顺序:

    for i in range(iterations):
        pos = (i % int(data_size / batch_size)) * batch_size
        current_data = data['data_set'][pos:pos+batch_size]
        sess.run(train_op, feed_dict={inputs=current_data})
    

    您可能想要编写一些更复杂的代码,随机遍历所有数据,但跟踪已使用的批次,因此您不会比其他批次更频繁地使用任何批次。完成训练集的完整运行后,再次启用所有批次并重复。

    【讨论】:

    • 这种方法在逻辑上似乎是正确的,但我使用它并没有得到任何积极的结果。我最好的猜测是:使用上面的代码示例 1,在每次迭代中,网络都会重新训练,忘记在前一个循环中学到的所有内容。因此,如果我们每次迭代获取 30 个样本或批次,在每次循环/迭代中,仅使用 30 个数据样本,然后在下一个循环中,所有内容都会被覆盖。
    【解决方案2】:

    您可以将 hdf5 数据集读入 numpy 数组,并将 numpy 数组的切片提供给 TensorFlow 模型。像下面这样的伪代码可以工作:

    import numpy, h5py
    f = h5py.File('somefile.h5','r')
    data = f.get('path/to/my/dataset')
    data_as_array = numpy.array(data)
    for i in range(0, 21760, 128):
      sess.run(train_op, feed_dict={input:data_as_array[i:i+128, :, :, :]})
    

    【讨论】:

    • 谢谢。但是当训练迭代次数i 很大时,例如10万,怎么养?
    • 如果你只有21760训练样本,你只有21760/128不同的小批量。您必须围绕 i 循环编写一个外部循环,并在训练数据集上运行许多 epoch。
    • 我有一点令人困惑。当原始数据被shuffle然后提取mini-batch时,是否意味着mini-batch的个数超过了21760/128
    【解决方案3】:

    alkamen's 方法在逻辑上似乎是正确的,但我没有得到任何积极的结果。我最好的猜测是:使用上面的代码示例 1,在每次迭代中,网络都会重新训练,忘记在前一个循环中学到的所有内容。因此,如果我们每次迭代获取 30 个样本或批次,在每个循环/迭代中,仅使用 30 个数据样本,然后在下一个循环中,所有内容都会被覆盖。

    在下面找到这种方法的屏幕截图

    可以看出,损失和准确性总是重新开始。如果有人可以分享解决此问题的可能方法,我会很高兴。

    【讨论】:

    • 您在其他用户中添加了标签,我的名字拼写为“n”,而不是“m”=)
    • 您的准确性不会被重置,它会随着每次迭代而提高,它不会回到零。您确定每次获取批次时都会获得一个全新的批次,并且它们没有高度重叠吗?这可以解释为什么您的准确性最初会提高这么多,因为您基本上为每次迭代重复使用相同的训练数据。然后,当您重置数据并获取新批次时,您可能会再次随机化并获得一组新的重叠批次,其中包含您的网络以前从未见过的数据。
    • 感谢您的评论。是的,我每次都按照我的算法获取新批次,是的,数据被打乱了,但这就是我最终得到的结果(我可能错了),但我感觉我之前的答案就是正在发生的事情。我会继续环顾四周。如果我确实找到任何东西,我会很乐意分享。还有.....对不起,我没记好你的名字。谢谢你的时间。干杯!
    • 好的。如果它确实重置并且您确定您的批次没有重叠,则可能不是数据获取错误,而是模型权重处理。我希望你能找到问题,祝你好运。
    • @CAta.RAy 不幸的是,我在 TensorFlow 上没有任何运气。我创建了一个 github Gist 代码来帮助您理解。所以我转而使用 keras。我构建了一个自定义生成器,用于批量获取数据。请在此处找到它 (gist.github.com/rocksyne/a4022afd7a5aaacdfb873218dba21d0c)。这个函数被称为 Kera 的 fit_generator 函数 (pyimagesearch.com/2018/12/24/…) 如果您可以分享更多您正在做的事情,我可以更好地理解您提供更多量身定制的答案。
    猜你喜欢
    • 1970-01-01
    • 2018-02-08
    • 2020-06-19
    • 1970-01-01
    • 2017-11-09
    • 2018-12-27
    • 2018-05-06
    • 2014-09-21
    • 2015-11-02
    相关资源
    最近更新 更多