下面是在TensorFlow 的CIFAR10 数据集上使用KFold 交叉验证的代码示例。
获取数据
import numpy as np
import pandas as pd
import tensorflow as tf
# Load CIFAR-10 and keep only the first split returned by load_data()
# (used here as the full training pool); the second split is discarded.
(input, target), (_, _) = tf.keras.datasets.cifar10.load_data()
# Cast the images to float32 and divide by 255
# (presumably 8-bit pixel values, which would map them into [0, 1]).
input = input.astype('float32') / 255
# One-hot encode the labels into 10 classes.
target = tf.keras.utils.to_categorical(target , num_classes=10)
print(input.shape, target.shape)
# (50000, 32, 32, 3) (50000, 10)
# NOTE(review): `input` shadows the Python builtin of the same name; it is
# kept because the K-fold loop later in this file indexes `input`/`target`.
模型
def my_model():
    """Build a new, uncompiled CNN classifier for 32x32 RGB images.

    The network stacks five 3x3 ``Conv2D`` layers with ReLU activations
    (16, 32, 64, 128 and 256 filters), followed by global average pooling
    and a 10-way softmax ``Dense`` layer.

    Returns:
        A newly constructed ``tf.keras.Sequential`` model. Every call builds
        new layer objects, so callers get an independent model each time.
    """
    return tf.keras.Sequential(
        [
            tf.keras.Input(shape=(32, 32, 3)),
            tf.keras.layers.Conv2D(16, 3, activation="relu"),
            tf.keras.layers.Conv2D(32, 3, activation="relu"),
            tf.keras.layers.Conv2D(64, 3, activation="relu"),
            tf.keras.layers.Conv2D(128, 3, activation="relu"),
            tf.keras.layers.Conv2D(256, 3, activation="relu"),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(10, activation='softmax'),
        ]
    )
我们将在 KFold 循环中调用这个函数来构建模型。
K 折训练
from sklearn.model_selection import KFold
import numpy as np
# 3-fold cross-validation: each iteration builds a new model, trains it on
# the fold's training indices and validates on the held-out indices.
for kfold, (train, test) in enumerate(
    KFold(n_splits=3, shuffle=True).split(input, target)
):
    # Clear the Keras session before building this fold's model.
    tf.keras.backend.clear_session()

    # Build and compile a new model so no weights carry over between folds.
    seq_model = my_model()
    seq_model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        # Pass metrics as a list, the form documented for `compile()`;
        # a bare metric object is rejected by newer Keras releases.
        metrics=[tf.keras.metrics.CategoricalAccuracy()],
        optimizer=tf.keras.optimizers.Adam(),
    )

    print('Train Set')
    print(input[train].shape)
    print(target[train].shape)

    print('Test Set')
    print(input[test].shape)
    print(target[test].shape)

    # Train on this fold's training subset, validating on the held-out subset.
    seq_model.fit(
        input[train],
        target[train],
        batch_size=128,
        epochs=2,
        validation_data=(input[test], target[test]),
    )

    # Save this fold's weights under a fold-specific file name.
    # NOTE(review): newer Keras releases may require the name to end in
    # `.weights.h5` -- confirm against the installed version.
    seq_model.save_weights(f'wg_{kfold}.h5')
日志
Train Set
(33333, 32, 32, 3)
(33333, 10)
Test Set
(16667, 32, 32, 3)
(16667, 10)
Epoch 1/2
11s 41ms/step - loss: 1.9961 - categorical_accuracy: 0.2363 -
val_loss: 1.6851 - val_categorical_accuracy: 0.3435
Epoch 2/2
10s 37ms/step - loss: 1.6322 - categorical_accuracy: 0.3836 -
val_loss: 1.5780 - val_categorical_accuracy: 0.4193
Train Set
(33333, 32, 32, 3)
(33333, 10)
Test Set
(16667, 32, 32, 3)
(16667, 10)
Epoch 1/2
11s 39ms/step - loss: 2.0254 - categorical_accuracy: 0.2197 -
val_loss: 1.6799 - val_categorical_accuracy: 0.3601
Epoch 2/2
10s 37ms/step - loss: 1.6687 - categorical_accuracy: 0.3739 -
val_loss: 1.5222 - val_categorical_accuracy: 0.4362
Train Set
(33334, 32, 32, 3)
(33334, 10)
Test Set
(16666, 32, 32, 3)
(16666, 10)
Epoch 1/2
11s 41ms/step - loss: 2.0170 - categorical_accuracy: 0.2212 -
val_loss: 1.7452 - val_categorical_accuracy: 0.3134
Epoch 2/2
10s 37ms/step - loss: 1.7110 - categorical_accuracy: 0.3363 -
val_loss: 1.5928 - val_categorical_accuracy: 0.4164
更新
使用pandas 数据框的另一种方法。
# load
df = pd.read_csv('train.csv')

# [optional]: shuffle the rows and rebuild a fresh 0..n-1 index
df = df.sample(frac=1).reset_index(drop=True)

# fold split
kfold = KFold(n_splits=3, shuffle=True)
for each_fold, (trn_idx, val_idx) in enumerate(
    kfold.split(np.arange(df.shape[0]), df.target.values)
):
    # get the folded data; split() yields positional indices, hence iloc
    train_labels = df.iloc[trn_idx].reset_index(drop=True)
    val_labels = df.iloc[val_idx].reset_index(drop=True)

    # train_labels: training pairs (data + label)
    # do something
让我们再举一个多标签数据的例子。这次我们先在数据框中创建一个记录折编号的“kfold”列。请参阅下面的示例。
# !pip install iterative-stratification
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
# load the data
df = pd.read_csv('train.csv')

# add a new column, for now setting all values to -1
# we will update it later
df.loc[:, 'kfold'] = -1

# [optional]: shuffle; the index is rebuilt as 0..n-1, so the positional
# indices yielded by split() below can be used as labels with df.loc
df = df.sample(frac=1).reset_index(drop=True)

# as we are now dealing with multi label cases,
# there would be more than one column for target value
# so let's grab the target values only
# we will drop id / image_id / image_name / blabla.. column,
# and also the placeholder `kfold` column so it is not treated as a label
target = df.drop(['id', 'kfold'], axis=1).values

mskf = MultilabelStratifiedKFold(n_splits=5)
for each_fold, (trn_idx, val_idx) in enumerate(mskf.split(df, target)):
    # rows held out for validation in this split get this fold number
    df.loc[val_idx, 'kfold'] = each_fold  # updating `kfold` column
之后,“kfold”列中会出现 0 到 4 共 5 个折编号(对应 n_splits = 5),可以用 df.head() 检查。接下来,我们可以这样做:
def program(fold):
    """Split the module-level ``df`` into train/validation frames for a fold.

    Rows whose ``kfold`` value equals ``fold`` form the validation frame and
    all other rows form the training frame. Both frames get a fresh index.
    As written this is a template: the frames are built but not returned.

    Args:
        fold: Value of the ``kfold`` column to hold out for validation.
    """
    # get the folded data
    train_labels = df[df.kfold != fold].reset_index(drop=True)
    val_labels = df[df.kfold == fold].reset_index(drop=True)

    # train_labels: training pairs (data + label)
    # do something


# Number of folds to iterate over; must match the `n_splits` value used when
# the `kfold` column was filled (5 in the snippet above).
n_split = 5
for i in range(n_split):
    program(fold=i)