【问题标题】:Restricting the size of MNIST training data限制 MNIST 训练数据的大小
【发布时间】:2019-07-24 00:02:02
【问题描述】:

我刚刚开始学习 python 和 TensorFlow,并且正在尝试各种神经网络和 MNIST 数据。我想做的一个实验是看看训练集的大小如何影响性能。目前,训练集中似乎有 55000 个输入/输出对。我想通过某种方式将训练限制为仅使用前 1000 个左右,但不知道如何实现。

我目前的训练功能是这样的:

def do_training():
    print("Train entry")
    for i in range(2000):

        batch_of_training_inputs, batch_of_training_labels = mnist.train.next_batch(100)

        sess.run(train_step, feed_dict={generic_image_data_struct: batch_of_training_inputs, target_for_output_struct: batch_of_training_labels })

有没有类似...

mnist.train.next_batch(100, BUT_ONLY_FROM_FIRST(1000))

仅供参考,我用这段代码得到了 mnist:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

【问题讨论】:

    标签: python tensorflow mnist


    【解决方案1】:

    您可以做的一件简单的事情就是增加验证数据集的大小。 MNIST 包含 60,000 张图像,因此如果您只想训练 1,000 张图像,您可以这样做:

    mnist = input_data.read_data_sets(train_dir, one_hot=True, validation_size=59000)
    

    【讨论】:

      【解决方案2】:

      通过一些黑客攻击,我认为这可能会奏效。尽管我真的不建议将来依赖此解决方案,因为它取决于 DataSet.__init__ 方法的内部实现以某种方式表现。对于一个快速的实验,它可能没问题。

      from tensorflow.examples.tutorials.mnist import input_data
      from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
      from tensorflow.python.framework import dtypes
      
      mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
      train_small = DataSet(mnist.train.images[:1000], mnist.train.labels[:1000],
                            dtype=dtypes.uint8, reshape=False, seed=None)
      

      【讨论】: