【发布时间】:2021-03-04 02:27:50
【问题描述】:
我正在尝试使用 keras 编写知识蒸馏模型。我从 keras 示例 here 开始。为了训练模型,覆盖了 train_step 和 test_step 方法。与 keras 示例不同,我想使用 ImageDataGenerator 对象来拟合模型,以预处理 CIFAR10 数据集中的图像。问题是,每当我调用传递 X_train 和 Y_train 的 model.fit 函数时,训练工作正常,如果我调用 model.fit 传递 ImageDataGenerator.flow(X_train, Y_train, batch_size) 代码返回以下错误:
NotImplementedError:子类化Model类时,应该实现调用方法。
我也尝试过修改 train_step 处理它接收到的数据输入的方式,但似乎到目前为止还没有任何方法奏效。
为什么会这样?用 ImageDataGenereator 对象覆盖 Model 类的 train_step 方法有什么问题吗?类Model的fit方法也应该被覆盖吗?
为了让事情变得清晰和可重现,这里是示例代码:
import time
import copy
import tensorflow as tf
import keras
from keras import regularizers
from keras.engine import Model
from keras.layers import Dropout, Flatten, Dense, Conv2D, MaxPooling2D, Activation, BatchNormalization
from keras.models import Sequential
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
from tensorflow.python.keras.engine import data_adapter
# Imported from files
import settings_parser
from utils import progressive_learning_rate
from teacher import Teacher, build_teacher
from student import Student, build_student
class Distiller(tf.keras.Model):
def __init__(self, student, teacher):
super(Distiller, self).__init__()
self.teacher = teacher
self.student = student
def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
""" Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between student
predictions and ground-truth
distillation_loss_fn: Loss function of difference between soft
student predictions and soft teacher predictions
alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
temperature: Temperature for softening probability distributions.
Larger temperature gives softer distributions.
"""
super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
# @tf.function
def train_step(self, data):
# Treat data in different ways if it is a tuple or an iterator
x = None
y = None
if isinstance(data, tuple):
x, y = data
if isinstance(data, tf.keras.preprocessing.image.NumpyArrayIterator):
x, y = data.next()
# Forward pass of teacher
teacher_predictions = self.teacher(x, training=False)
with tf.GradientTape() as tape:
# Forward pass of student
student_predictions = self.student(x, training=True)
# Compute losses
student_loss = self.student_loss_fn(y, student_predictions)
distillation_loss = self.distillation_loss_fn(
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
tf.nn.softmax(student_predictions / self.temperature, axis=1),
)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
# Compute gradients
trainable_vars = self.student.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update the metrics configured in `compile()`.
self.compiled_metrics.update_state(y, student_predictions)
# Return a dict of performance
results = {m.name: m.result() for m in self.metrics}
results.update(
{"student_loss": student_loss, "distillation_loss": distillation_loss}
)
return results
# @tf.function
def test_step(self, data):
# Treat data in different ways if it is a tuple or an iterator
x = None
y = None
if isinstance(data, tuple):
x, y = data
if isinstance(data, tf.keras.preprocessing.image.NumpyArrayIterator):
x, y = data.next()
# Compute predictions
y_prediction = self.student(x, training=False)
# Calculate the loss
student_loss = self.student_loss_fn(y, y_prediction)
# Update the metrics.
self.compiled_metrics.update_state(y, y_prediction)
# Return a dict of performance
results = {m.name: m.result() for m in self.metrics}
results.update({"student_loss": student_loss})
return results
#Define method to build the teacher model (VGG16)
def build_teacher():
input = keras.Input(shape=(32, 32, 3), name="img")
x = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(input)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.3)(x)
# Block 2
x = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Block 3
x = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 4
x = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Block 5
x = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 6
x = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 7
x = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Block 8
x = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 9
x = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 10
x = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Block 11
x = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 12
x = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 13
x = Conv2D(512, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Dropout(0.5)(x)
# Flatten and classification
x = Flatten()(x)
x = Dense(512)(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
# Out
x = Dense(10)(x)
output = Activation('softmax')(x)
# Define model from input and output
model = keras.Model(input, output, name="teacher")
print(model.summary())
return model
#Define method to build the teacher model (VGG16)
def build_student():
input = keras.Input(shape=(32, 32, 3), name="img")
x = Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(input)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.3)(x)
# Block 2
x = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Block 3
x = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
# Block 4
x = Conv2D(128, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Block 5
x = Conv2D(256, (3, 3), padding='same', kernel_regularizer=regularizers.l2(0.0005))(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Flatten and classification
x = Flatten()(x)
x = Dense(512)(x)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
# Out
x = Dense(10)(x)
output = Activation('softmax')(x)
# Define model from input and output
model = keras.Model(input, output, name="student")
print(model.summary())
return model
if __name__ == '__main__':
args = settings_parser.arg_parse()
print_during_epochs = True
student = build_student()
student_clone = build_student()
student_clone.set_weights(student.get_weights())
teacher = build_teacher()
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)
train_datagen = ImageDataGenerator(
rescale=1. / 255, # rescale input image
featurewise_center=False, # set input mean to 0 over the dataset
samplewise_center=False, # set each sample mean to 0
featurewise_std_normalization=False, # divide inputs by std of the dataset
samplewise_std_normalization=False, # divide each input by its std
zca_whitening=False, # apply ZCA whitening
rotation_range=15, # randomly rotate images in the range (degrees, 0 to 180)
width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
horizontal_flip=True, # randomly flip images
vertical_flip=False) # randomly flip images)
train_datagen.fit(X_train)
train_generator = train_datagen.flow(X_train, Y_train, batch_size=64)
test_datagen = ImageDataGenerator(rescale=1. / 255)
test_generator = test_datagen.flow(X_test, Y_test, batch_size=64)
# Train teacher as usual
teacher.compile(optimizer=keras.optimizers.SGD(),
loss=keras.losses.categorical_crossentropy,
metrics=['accuracy'])
# Train and evaluate teacher on data.
teacher.fit(train_generator, validation_data=test_generator, epochs=5, verbose=print_during_epochs)
loss, acc = teacher.evaluate(test_generator)
print("Teacher model, accuracy: {:5.2f}%".format(100 * acc))
# Train student as doen usually
student_clone.compile(optimizer=keras.optimizers.SGD(),
loss=keras.losses.categorical_crossentropy,
metrics=['accuracy'])
# Train and evaluate student trained from scratch.
student_clone.fit(train_generator, validation_data=test_generator, epochs=5, verbose=print_during_epochs)
loss, acc = student_clone.evaluate(test_generator)
print("Student scratch model, accuracy: {:5.2f}%".format(100 * acc))
#print('{}\n\n{}'.format(teacher.summary(), student_clone.summary()))
# Train student using knowledge distillation
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(optimizer=keras.optimizers.SGD(),
metrics=['accuracy'],
student_loss_fn=keras.losses.CategoricalCrossentropy(), # categorical_crossentropy,
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10)
# Distill teacher to student
distiller.fit(X_train, Y_train, epochs=5) #THIS WORKS FINE
distiller.fit(train_generator, validation_data=test_generator, epochs=5,
verbose=print_during_epochs) # THIS DOESN'T WORK
# Evaluate student on test dataset
loss, acc = distiller.evaluate(test_generator)
print("Student distilled model, accuracy: {:5.2f}%".format(100 * acc))
【问题讨论】:
-
你能添加完整的回溯吗?
标签: tensorflow keras deep-learning training-data model-fitting