Update
Based on your query, it seems the number of classes won't differ in Dataset2. Also, you don't want to use the ImageNet weights. In that case, you don't need to map or store weights (as described below). Simply load the model with its weights and train on Dataset2: freeze all the layers trained on Dataset1 and train only the last layer on Dataset2. Very straightforward.
Even though you don't need the full details, I'm keeping my answer below as-is for future reference.
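That freeze-and-retrain idea can be sketched as follows (a minimal sketch; the small Sequential model here is just a stand-in for your model already trained on Dataset1):

```python
import tensorflow as tf

# stand-in for a model already trained on Dataset1
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.GlobalMaxPooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# freeze everything except the classifier head
for layer in model.layers[:-1]:
    layer.trainable = False

# re-compile after changing trainable flags, then fit on Dataset2
model.compile(optimizer='adam', loss='categorical_crossentropy')
print([layer.trainable for layer in model.layers])  # [False, False, True]
```

Remember to compile again after toggling `trainable`, otherwise the change is not picked up by training.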
Here is a small demo of what you may need; hopefully it gives you some insight. We will train on the CIFAR dataset (10 classes) and then try to use that model for transfer learning on a different dataset that may have a different input size and a different number of classes.
Prepare CIFAR (10 classes)
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# train set / data
x_train = x_train.astype('float32') / 255
# validation set / data
x_test = x_test.astype('float32') / 255
# train set / target
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
# validation set / target
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
'''
(50000, 32, 32, 3) (50000, 10)
(10000, 32, 32, 3) (10000, 10)
'''
Model
# declare input shape
input = tf.keras.Input(shape=(32,32,3))
# Block 1
x = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")(input)
x = tf.keras.layers.MaxPooling2D(3)(x)
# Now apply global max pooling.
gap = tf.keras.layers.GlobalMaxPooling2D()(x)
# Finally, we add a classification layer.
output = tf.keras.layers.Dense(10, activation='softmax')(gap)
# bind all
func_model = tf.keras.Model(input, output)
'''
Model: "functional_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 15, 15, 32) 896
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 32) 0
_________________________________________________________________
global_max_pooling2d_1 (Glob (None, 32) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 330
=================================================================
Total params: 1,226
Trainable params: 1,226
Non-trainable params: 0
'''
Run the model so it learns some weight matrices:
# compile
print('\nFunctional API')
func_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam())
# fit
func_model.fit(x_train, y_train, batch_size=128, epochs=1)
Transfer Learning
Let's use the model for MNIST. It also has 10 classes, but to simulate a different number of classes we will derive even and odd categories from it (2 classes). Here is how we prepare those datasets:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# train set / data
x_train = np.expand_dims(x_train, axis=-1)
x_train = np.repeat(x_train, 3, axis=-1)
x_train = x_train.astype('float32') / 255
# train set / target
y_train = tf.keras.utils.to_categorical((y_train % 2 == 0).astype(int),
                                        num_classes=2)
# validation set / data
x_test = np.expand_dims(x_test, axis=-1)
x_test = np.repeat(x_test, 3, axis=-1)
x_test = x_test.astype('float32') / 255
# validation set / target
y_test = tf.keras.utils.to_categorical((y_test % 2 == 0).astype(int),
                                       num_classes=2)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
'''
(60000, 28, 28, 3) (60000, 2)
(10000, 28, 28, 3) (10000, 2)
'''
If you are familiar with how ImageNet pre-trained weights are used with keras models, you have probably used include_top. By setting it to False, we can easily load a weight file without the top (classifier) part of the pre-trained model. Here we need to do that (somewhat) manually: grab the weight matrices up to the last activation layer (in our case Dense(10, softmax)), put them into a new instance of the base model, and then add a new classifier layer (in our case Dense(2, softmax)).
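For comparison, the built-in include_top pattern looks roughly like this (a sketch; weights=None is used here only to skip the ImageNet download, and MobileNetV2 is an arbitrary choice of backbone):

```python
import tensorflow as tf

# load a keras application without its classifier head
base = tf.keras.applications.MobileNetV2(
    input_shape=(96, 96, 3), include_top=False, weights=None)

# attach a new 2-class head, as we will do manually below
x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
out = tf.keras.layers.Dense(2, activation='softmax')(x)
model = tf.keras.Model(base.input, out)
print(model.output_shape)  # (None, 2)
```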
for i, layer in enumerate(func_model.layers):
    print(i, '\t', layer.trainable, '\t :', layer.name)
'''
Train_Bool : Layer Names
0 True : input_1
1 True : conv2d
2 True : max_pooling2d
3 True : global_max_pooling2d # < we grab the weights up to here
4 True : dense # 10 classes (from previous model)
'''
Get the weights

sparsified_weights = []
# GlobalMaxPooling2D and MaxPooling2D have no trainable weights, so the
# only pre-classifier layer whose weights we need to copy is conv2d
for w in func_model.get_layer(name='conv2d').get_weights():
    sparsified_weights.append(w)

With this, we copy the weights of the old model except for the classifier layer (Dense). Note that we grab everything up to the GAP layer, which sits right before the classifier; of those layers, only conv2d actually holds weights.
Now we will create a new model that is identical to the old one except for the last layer (Dense(10)), adding a new Dense layer with 2 units instead.

predictions = Dense(2, activation='softmax')(func_model.layers[-2].output)
new_func_model = Model(inputs=func_model.inputs, outputs=predictions)

Now we can set the weights on the new model as follows:

new_func_model.get_layer(name='conv2d').set_weights(sparsified_weights)
You can verify this as follows; everything is the same except the last layer:
func_model.get_weights() # last layer, Dense (10)
new_func_model.get_weights() # last layer, Dense (2)
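That comparison can also be done programmatically (a sketch that rebuilds the two models on a shared backbone, as in the text, and checks the conv weights tensor by tensor):

```python
import numpy as np
import tensorflow as tf

# rebuild the two models on a shared backbone, as in the text
inp = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(32, 3, strides=2, activation='relu',
                           name='conv2d')(inp)
gap = tf.keras.layers.GlobalMaxPooling2D()(x)
old_model = tf.keras.Model(inp, tf.keras.layers.Dense(10, activation='softmax')(gap))
new_model = tf.keras.Model(inp, tf.keras.layers.Dense(2, activation='softmax')(gap))

# every pre-classifier weight tensor should match exactly
for w_old, w_new in zip(old_model.get_layer('conv2d').get_weights(),
                        new_model.get_layer('conv2d').get_weights()):
    assert np.array_equal(w_old, w_new)
print('backbone weights match')
```

Note that when the new model is built from the old model's layer objects (as with `func_model.layers[-2].output`), the backbone weights are already shared; the explicit get/set copy shown earlier is the general pattern for a fresh model instance.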
Now you can train the model on the new dataset, in our case MNIST:
new_func_model.compile(optimizer='adam', loss='categorical_crossentropy')
new_func_model.summary()
'''
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 15, 15, 32) 896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 5, 5, 32) 0
_________________________________________________________________
global_max_pooling2d (Global (None, 32) 0
_________________________________________________________________
dense_6 (Dense) (None, 2) 66
=================================================================
Total params: 962
Trainable params: 962
Non-trainable params: 0
'''
# compile
print('\nFunctional API')
new_func_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam())
# fit
new_func_model.fit(x_train, y_train, batch_size=128, epochs=1)
'''
WARNING:tensorflow:Model was constructed with shape (None, 32, 32, 3) for input Tensor("input_1:0", shape=(None, 32, 32, 3), dtype=float32), but it was called on an input with incompatible shape (None, 28, 28, 3).
WARNING:tensorflow:Model was constructed with shape (None, 32, 32, 3) for input Tensor("input_1:0", shape=(None, 32, 32, 3), dtype=float32), but it was called on an input with incompatible shape (None, 28, 28, 3).
469/469 [==============================] - 1s 3ms/step - loss: 0.6453 - categorical_accuracy: 0.6447
<tensorflow.python.keras.callbacks.History at 0x7f7af016feb8>
'''
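The warnings above are harmless here because every layer in this small model adapts to the spatial size, but if you prefer to feed the exact shape the model was built with, you can resize the new data first (a sketch using tf.image.resize on a random stand-in batch):

```python
import numpy as np
import tensorflow as tf

# hypothetical 28x28x3 MNIST-style batch, resized to the model's
# 32x32x3 build shape before calling fit/predict
x = np.random.rand(4, 28, 28, 3).astype('float32')
x_resized = tf.image.resize(x, size=(32, 32)).numpy()
print(x_resized.shape)  # (4, 32, 32, 3)
```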