[Posted]: 2021-07-17 02:11:16
[Question]:
I am training a DenseNet-based feature extractor, shown below:
# Import the Sequential model and layers
from keras.models import Sequential
import keras
import tensorflow as tf
from keras.layers import Conv2D, MaxPooling2D, Lambda, Dropout, Concatenate
from keras.layers import Activation, Dropout, Flatten, Dense
import pandas as pd
from sklearn import preprocessing
import ast
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint

size = 256

class DenseNetBase(tf.keras.Model):
    def __init__(self, size, include_top = True):
        super(DenseNetBase, self).__init__()
        self.include_top = include_top
        #base
        self.base = tf.keras.applications.DenseNet201(weights='imagenet',include_top=False, pooling='avg',input_shape = (size,size,3))
        #final layer
        self.dense = Dense(1, activation='sigmoid', name='predictions')

    def call(self, input_tensor):
        input_image = input_tensor[0]
        input_metafeatures = input_tensor[1]
        #model
        x = self.base(input_image)
        if self.include_top:
            x = self.dense(x)
        return x

    def build_graph(self):
        x = self.base.input
        y = tf.keras.Input(shape=(3,))
        return tf.keras.Model(inputs=[x,y], outputs=self.call([x,y]))
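I then want to take DenseNetBase, keep its trained weights, and remove the final dense layer so that it can be used for feature extraction. A simplified DenseClassifier looks like this: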
class DenseClassifier(tf.keras.Model):
    def __init__(self, size, feature_extractor):
        super(DenseClassifier, self).__init__()
        #base tf.keras.layers.Input(shape=(size,size,3))
        self.feature_extractor = tf.keras.Model(inputs = tf.keras.Input(shape=(size,size,3)), outputs = feature_extractor.layers[-2].output)
        #final layer
        self.dense = Dense(1, activation='sigmoid', name='prediction')

    def call(self, input_tensor):
        input_image = input_tensor[0]
        input_metafeatures = input_tensor[1]
        #model
        x = self.feature_extractor(input_image)
        return self.dense(x)

    def build_graph(self):
        x = self.base.input
        y = tf.keras.Input(shape=(3,))
        return tf.keras.Model(inputs=[x,y], outputs=self.call([x,y]))
Tying it together:
#build densenet feature extractor we have trained
denseBase = DenseNetBase(256, include_top = True)
denseBase.build([(None, 256, 256, 3), (None,3)])
denseBase.load_weights('./models/DenseBaseSimple.h5')
#this doesn't work
DenseClassifier = DenseClassifier(size = 256, feature_extractor = denseBase)
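In the example above I get an input error, and I don't know why. The expected behavior is that I can build and compile the latter model, with the existing DenseNetBase weights used for feature extraction.
I tried replacing the inputs argument with inputs = feature_extractor.layers[-2].input. That does compile, but even though it uses the same weights (there are no extra layers in the simple example above), its evaluation does not seem to reach the same accuracy as denseBase.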
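A minimal sketch of what the input error is likely about, assuming that feature_extractor.layers[-2] is the DenseNet201 sub-model (the first tracked layer of DenseNetBase). Its output tensor is wired only to that sub-model's own input, so a freshly created tf.keras.Input has no path to it. The tiny convolutional `base` below is a hypothetical stand-in for DenseNet201, used only to keep the example small and offline.

```python
import tensorflow as tf

# Hypothetical stand-in for DenseNet201(include_top=False, pooling='avg').
inp = tf.keras.Input(shape=(8, 8, 3))
feat = tf.keras.layers.Conv2D(4, 3)(inp)
feat = tf.keras.layers.GlobalAveragePooling2D()(feat)
base = tf.keras.Model(inp, feat)

# A brand-new Input is not connected to base.output, so Keras rejects the graph.
try:
    tf.keras.Model(inputs=tf.keras.Input(shape=(8, 8, 3)), outputs=base.output)
    disconnected = False
except ValueError:
    disconnected = True

# Using the sub-model's own input is accepted, but the result is simply the
# same graph as `base` again.
same_graph = tf.keras.Model(inputs=base.input, outputs=base.output)
features = same_graph(tf.zeros((2, 8, 8, 3)))
```

If that reading is right, the second attempt (inputs = feature_extractor.layers[-2].input) does reuse the trained DenseNet weights. One possible reason for the accuracy gap, judging only from the code shown, is that DenseClassifier creates a new Dense layer whose weights are randomly initialized rather than loaded from the checkpoint.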
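My goals/questions:
- How do I load the weights from the pretrained denseBase but remove the final dense layer (so the output is (None, 1920), like DenseNet without the top but using my weights)?
- How can I then load this model, without the dense layer, into another subclassed model to extract features?
Thanks!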
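One way the two goals might be approached, offered as a sketch and not a verified answer: since the DenseNet sub-model is stored as the `base` attribute, the same layer object can be passed to the second model, and its weights come with it. The `make_base` helper is hypothetical and stands in for DenseNet201 so that the example runs without downloading weights. The classes are simplified versions of the ones in the question.

```python
import tensorflow as tf


def make_base(size):
    # Hypothetical stand-in for DenseNet201(include_top=False, pooling='avg').
    inp = tf.keras.Input(shape=(size, size, 3))
    x = tf.keras.layers.Conv2D(4, 3, activation="relu")(inp)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    return tf.keras.Model(inp, x)


class DenseNetBase(tf.keras.Model):
    def __init__(self, size, include_top=True):
        super().__init__()
        self.include_top = include_top
        self.base = make_base(size)
        self.dense = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        x = self.base(inputs[0])
        return self.dense(x) if self.include_top else x


class DenseClassifier(tf.keras.Model):
    def __init__(self, feature_extractor):
        super().__init__()
        # Reuse the already-built sub-model object; no new Input is created.
        self.feature_extractor = feature_extractor
        # This layer is new, so its weights start untrained.
        self.dense = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        return self.dense(self.feature_extractor(inputs[0]))


images = tf.random.uniform((2, 16, 16, 3))
meta = tf.zeros((2, 3))

trained = DenseNetBase(16)
_ = trained([images, meta])  # creates the weights; load_weights(...) would follow here

features = trained.base(images)       # pooled features, no final dense layer
clf = DenseClassifier(trained.base)   # shares the same weights
pred = clf([images, meta])
```

With the real DenseNet201 the pooled feature width would be 1920 instead of 4. Because `clf.dense` is new, `clf` would still need training (or a copy of the old dense weights) before its accuracy could match the original model.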
[Discussion]:
Tags: python tensorflow machine-learning keras