假设您不想维持tf.keras.datasets.mnist API 提供的训练和测试之间的默认拆分,您可以将训练和测试集添加到一起,然后根据您的比率将它们迭代地拆分为训练、验证和测试。
from sklearn.model_selection import train_test_split
import tensorflow as tf
DATASET_SIZE = 70000
TRAIN_RATIO = 0.7
VALIDATION_RATIO = 0.2
TEST_RATIO = 0.1
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
X = np.concatenate([x_train, x_test])
y = np.concatenate([y_train, y_test])
如果您希望数据集是 numpy 数组,您可以使用 sklearn.model_selection import train_test_split() 函数。
举个例子:
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=(1-TRAIN_RATIO))
X_val, X_test, y_val, y_test = train_test_split(X_val, y_val, test_size=((TEST_RATIO/(VALIDATION_RATIO+TEST_RATIO))))
如果您更喜欢使用 tf Dataset API,那么您可以使用 .take() 和 .skip() 方法,如下所示:
dataset = tf.data.Dataset.from_tensor_slices((X, y))
train_dataset = dataset.take(int(TRAIN_RATIO*DATASET_SIZE))
validation_dataset = dataset.skip(int(TRAIN_RATIO*DATASET_SIZE)).take(int(VALIDATION_RATIO*DATASET_SIZE))
test_dataset = dataset.skip(int(TRAIN_RATIO*DATASET_SIZE)).skip(int(VALIDATION_RATIO*DATASET_SIZE))
此外,您可以在拆分之前将.shuffle() 添加到您的数据集以生成混洗分区:
dataset = dataset.shuffle()