【问题标题】:Mnist dataset splittingMnist 数据集拆分
【发布时间】:2021-04-08 20:47:43
【问题描述】:

任何人都可以帮助我按照我们希望的比率将 mnist 数据集拆分为训练、测试和验证。

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

使用 70-20-10 拆分进行训练、验证和测试。

【问题讨论】:

  • 嗨@TaranjeetSinh!请查看我的答案,如果这是您要找的,请考虑接受它作为正确的答案

标签: tensorflow machine-learning keras training-data mnist


【解决方案1】:

这种方法应该可以做到。它基本上迭代地使用来自 tensorflow 的 train_test_split 函数将数据集拆分为验证测试​​训练:

train_ratio = 0.70
validation_ratio = 0.20
test_ratio = 0.10

# train is now 70% of the entire data set
# the _junk suffix means that we drop that variable completely
x_train, x_test, y_train, y_test = train_test_split(dataX, dataY, test_size=1 - train_ratio)

# test is now 10% of the initial data set
# validation is now 20% of the initial data set
x_val, x_test, y_val, y_test = train_test_split(x_test, y_test, test_size=test_ratio/(test_ratio + validation_ratio)) 

【讨论】:

    【解决方案2】:

    假设您不想维持tf.keras.datasets.mnist API 提供的训练和测试之间的默认拆分,您可以将训练和测试集添加到一起,然后根据您的比率将它们迭代地拆分为训练、验证和测试。

    from sklearn.model_selection import train_test_split
    import tensorflow as tf
    
    DATASET_SIZE = 70000
    TRAIN_RATIO = 0.7
    VALIDATION_RATIO = 0.2
    TEST_RATIO = 0.1
    
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    
    X = np.concatenate([x_train, x_test])
    y = np.concatenate([y_train, y_test])
    

    如果您希望数据集是 numpy 数组,您可以使用 sklearn.model_selection import train_test_split() 函数。 举个例子:

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=(1-TRAIN_RATIO))
    X_val, X_test, y_val, y_test = train_test_split(X_val, y_val, test_size=((TEST_RATIO/(VALIDATION_RATIO+TEST_RATIO))))
    

    如果您更喜欢使用 tf Dataset API,那么您可以使用 .take().skip() 方法,如下所示:

    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    
    train_dataset = dataset.take(int(TRAIN_RATIO*DATASET_SIZE))
    validation_dataset = dataset.skip(int(TRAIN_RATIO*DATASET_SIZE)).take(int(VALIDATION_RATIO*DATASET_SIZE))
    test_dataset = dataset.skip(int(TRAIN_RATIO*DATASET_SIZE)).skip(int(VALIDATION_RATIO*DATASET_SIZE))
    

    此外,您可以在拆分之前将.shuffle() 添加到您的数据集以生成混洗分区:

    dataset = dataset.shuffle()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-09-04
      • 1970-01-01
      • 2021-09-07
      • 1970-01-01
      • 2019-05-14
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多