【问题标题】:Training MNIST data set in google colab issue: [duplicate]在 google colab 问题中训练 MNIST 数据集:[重复]
【发布时间】:2020-09-10 04:13:19
【问题描述】:

我在专业版的 google colab notebook 中执行 CNN。虽然 x_train 的形状是 (60,000, 28,28)。该模型仅在 1875 行上进行了训练。以前有人遇到过这个问题吗?我的模型在本地机器的 jupyter notebook 上运行良好。它在所有 60,000 行上运行

    import tensorflow as tf
    mnist = tf.keras.datasets.mnist

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    y_train = y_train.astype('float32') / 255.0

    print("x_train.shape:", x_train.shape)

    #Build the model
    from tensorflow.keras.layers import Dense, Flatten, Dropout
    model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28,28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation='softmax')
    ])

    r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs = 10)


    Output:

    x_train.shape: (60000, 28, 28)

    Epoch 1/10
    1875/1875 [==============================] - 3s 2ms/step - loss: 2.2912e-06 - accuracy:                            0.0987 - val_loss: 7716.5078 - val_accuracy: 0.0980

【问题讨论】:

    标签: python tensorflow machine-learning keras deep-learning


    【解决方案1】:

    1875 是批次数。默认情况下,批次包含 32 个样本。
    60000 / 32 = 1875

    【讨论】:

    • 我认为 32 是批次数,每批次存在 32 次(1875)= 60,000。但我想在整个数据集上进行训练。不想分批。我可以这样做吗?在我的本地机器上完成了 60,000 行的训练。问题出在 google colab notebook 上。
    • 好吧,我想在fit 函数中设置batch_size=60000 是唯一的方法
    • @Naveen 你在整个数据集上训练,并将其分成批次是这样做的标准方法。设置batch_size=60000 很可能会导致内存问题,这不是解决问题的方法。
    【解决方案2】:

    如果你使用keras,而不是tensorflow.keras,日志会显示:

    x_train.shape: (60000, 28, 28)
    Train on 60000 samples, validate on 10000 samples
    Epoch 1/10
    60000/60000 [==============================] - 6s 107us/step - loss: 0.9655 - val_loss: 20.2422
    

    但是两者内部是一样的,一个显示要训练的样本数(keras),另一个显示迭代次数(tf.keras)。

    您不可能一次训练所有 60000 个样本,我们需要对输入进行批处理,以免 GPU 内存不足。您可以尝试尽可能多地增加您的batch_size,但过了一段时间您会收到诸如 OOMError、CUDA 内存不足等错误。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2016-06-08
      • 2019-08-11
      • 2019-06-10
      • 2018-07-06
      • 2015-04-21
      • 2019-05-04
      • 2021-05-17
      • 1970-01-01
      相关资源
      最近更新 更多