【Question Title】: How to remove last layer in keras subclass model but keep weights?
【Posted】: 2021-07-17 02:11:16
【Question Description】:

I am training a DenseNet-based feature extractor as follows:

# Import the Sequential model and layers
from keras.models import Sequential
import keras
import tensorflow as tf
from keras.layers import Conv2D, MaxPooling2D, Lambda, Dropout, Concatenate
from keras.layers import Activation, Flatten, Dense
import pandas as pd
from sklearn import preprocessing
import ast
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint

size = 256

class DenseNetBase(tf.keras.Model):
    
    def __init__(self, size, include_top = True):
        
        super(DenseNetBase, self).__init__()
        
        self.include_top = include_top
        
        #base
        self.base = tf.keras.applications.DenseNet201(weights='imagenet',include_top=False, pooling='avg',input_shape = (size,size,3))
        
        #final layer
        self.dense = Dense(1, activation='sigmoid', name='predictions')
        
    def call(self, input_tensor):
        
        input_image = input_tensor[0]
        input_metafeatures = input_tensor[1]
        
        #model        
        x = self.base(input_image)
        
        if self.include_top:
            x = self.dense(x)
    
        return x
    
    def build_graph(self):
        x = self.base.input
        y = tf.keras.Input(shape=(3,))
        return tf.keras.Model(inputs=[x,y], outputs=self.call([x,y]))

I then want to take DenseNetBase, keep the trained weights, but remove the final dense layer so it can be used for feature extraction. The simplified DenseClassifier looks like this:

class DenseClassifier(tf.keras.Model):
    
    def __init__(self, size, feature_extractor):
        
        super(DenseClassifier, self).__init__()
        
        #base tf.keras.layers.Input(shape=(size,size,3))
        self.feature_extractor = tf.keras.Model(inputs = tf.keras.Input(shape=(size,size,3)), outputs = feature_extractor.layers[-2].output)             
        
        #final layer
        self.dense = Dense(1, activation='sigmoid', name='prediction')
        
    def call(self, input_tensor):
        
        input_image = input_tensor[0]
        input_metafeatures = input_tensor[1]
        
        #model        
        x = self.feature_extractor(input_image)
        
        return self.dense(x)
    
    def build_graph(self):
        x = self.base.input
        y = tf.keras.Input(shape=(3,))
        return tf.keras.Model(inputs=[x,y], outputs=self.call([x,y]))

Tying it together:

#build densenet feature extractor we have trained
denseBase = DenseNetBase(256, include_top = True)
denseBase.build([(None, 256, 256, 3), (None,3)])
denseBase.load_weights('./models/DenseBaseSimple.h5')

#this doesn't work
DenseClassifier = DenseClassifier(size = 256, feature_extractor = denseBase)

In the example above I get an input error, and I do not know why. The expected behaviour is that I can build and compile the latter model, and the existing DenseNetBase weights will be used for feature extraction.

I tried replacing the inputs part with inputs = feature_extractor.layers[-2].input, and that does compile, but its evaluation does not seem to reach the same accuracy as denseBase, even though it uses the same weights (and there are no extra layers in the simple example above).

My goals/questions:

  • How can I load the weights from the pre-trained denseBase but drop the last dense layer (so the output is (None, 1920): the DenseNet without its top, but with my trained weights)? A rough sketch of what I mean follows this list.
  • How can I then load this model, without the dense head, into another subclassed model to extract features?
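
Roughly what I am imagining for the first goal (an untested sketch using the denseBase built above; my assumption is that, because the subclass stores the DenseNet201 in self.base, the trained weights are reachable through denseBase.base, and the feature_extractor name here is just for illustration):

#denseBase.base is the DenseNet201 functional model built with pooling='avg',
#so its output should already be the (None, 1920) feature vector,
#carrying the trained weights and no sigmoid head
feature_extractor = tf.keras.Model(inputs = denseBase.base.input,
                                   outputs = denseBase.base.output)
print(feature_extractor.output_shape)  #expecting (None, 1920)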

Thanks!

【Question Comments】:

    Tags: python tensorflow machine-learning keras


    【Solution 1】:

    To answer my own question, I ran some tests on the values of the initialised weights, using the logic from here:

    This is as expected. DenseBaseClassifier (using denseBase) and the version using imagenet weights both have similarly initialised prediction weights, because in both cases that layer is randomly initialised and never trained, whereas the prediction layer in denseBase has been optimised and is therefore different.

    For the DenseNet part, DenseBaseClassifier (using denseBase) == denseBase (with some noise from only saving the weights), while the original imagenet weights are different.

    Using denseBase_featureextractor = tf.keras.Model(inputs = denseBase.layers[-2].input, outputs = denseBase.layers[-2].output) does preserve the weights.

    I am not sure why self.feature_extractor = tf.keras.Model(inputs = tf.keras.Input(shape=(size,size,3)), outputs = feature_extractor.layers[-2].output) does not work.
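
    My guess (an assumption, not verified against the original error) is that the fresh tf.keras.Input tensor is never wired into the graph that produces feature_extractor.layers[-2].output, so Keras ends up with a disconnected graph. Calling the extracted base on the new input would connect the two; a minimal sketch of what I mean inside DenseClassifier.__init__:

    base = feature_extractor.layers[-2]                  #the DenseNet201 sub-model
    image_input = tf.keras.Input(shape=(size, size, 3))  #fresh input, now actually connected
    self.feature_extractor = tf.keras.Model(inputs = image_input, outputs = base(image_input))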

    import numpy as np
    import matplotlib.pyplot as plt

    denseBase = DenseNetBase(size, include_top = True)
    denseBase.build([(None, 256, 256, 3), (None,3)])
    denseBase.load_weights('./models/DenseBaseSimple.h5')
    
    denseBase_featureextractor = tf.keras.Model(inputs = denseBase.layers[-2].input, outputs = denseBase.layers[-2].output)
    DenseClassifier_denseBase = DenseClassifier(size = 256, feature_extractor = denseBase_featureextractor)
    DenseClassifier_denseBase.build([(None, 256, 256, 3), (None,3)])
    
    denseBase_imagenet = tf.keras.applications.DenseNet201(weights='imagenet',include_top=False, pooling='avg',input_shape = (size,size,3))
    DenseClassifier_imagenet = DenseClassifier(size = 256, feature_extractor = denseBase_imagenet)
    DenseClassifier_imagenet.build([(None, 256, 256, 3), (None,3)])
    
    def get_weights_print_stats(layer):
        W = layer.get_weights()
        #print(len(W))
        #for w in W:
        #    print(w.shape)
        return W
    
    def hist_weights(weights, title, bins=500):
        for weight in weights[0:5]:
            plt.hist(np.ndarray.flatten(weight), bins=bins)
            plt.title(title)
    
    fig = plt.figure(figsize=(15, 10))
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    
    W = get_weights_print_stats(denseBase.layers[1])
    plt.subplot(2, 3, 1)
    hist_weights(W, "denseBase")
    y = plt.ylabel("Final prediction later weights")#, rotation="horizontal")
    
    W = get_weights_print_stats(DenseClassifier_denseBase.layers[1])
    plt.subplot(2, 3, 2)
    hist_weights(W, "DenseBaseClassifier (using denseBase weights)")
    
    W = get_weights_print_stats(DenseClassifier_imagenet.layers[1])
    plt.subplot(2, 3, 3)
    hist_weights(W, "DenseBaseClassifier (using imagenet weights)")
    
    W = get_weights_print_stats(denseBase.layers[0])
    plt.subplot(2, 3, 4)
    hist_weights(W, "denseBase")
    y = plt.ylabel("DenseNet base first 5 weights")#, rotation="horizontal")
    
    W = get_weights_print_stats(DenseClassifier_denseBase.layers[0])
    plt.subplot(2, 3, 5)
    hist_weights(W, "DenseBaseClassifier (using denseBase weights)")
    
    W = get_weights_print_stats(DenseClassifier_imagenet.layers[0])
    plt.subplot(2, 3, 6)
    hist_weights(W, "DenseBaseClassifier (using imagenet weights)")
    

    【Comments】:
