【问题标题】:keras setting trainable flag on pretrained modelkeras 在预训练模型上设置可训练标志
【发布时间】:2023-03-13 04:15:01
【问题描述】:

假设我有一个模型

from tensorflow.keras.applications import DenseNet201

base_model = DenseNet201(input_tensor=Input(shape=basic_shape))

model = Sequential()
model.add(base_model)

model.add(Dense(400))
model.add(BatchNormalization())
model.add(ReLU())

model.add(Dense(50, activation='softmax'))

model.save('test.hdf5')

然后我加载保存的模型并尝试使最后 40 层 DenseNet201 可训练,前 161 层 - 不可训练:

saved_model = load_model('test.hdf5')
cnt = 44
saved_model.trainable = False
  while cnt > 0:
      saved_model.layers[-cnt].trainable = True
      cnt -= 1

但这实际上不起作用,因为DenseNet201 被确定为单层,我只是得到索引超出范围错误。

Layer (type)                 Output Shape              Param #   
=================================================================
densenet201 (Functional)     (None, 1000)              20242984  
_________________________________________________________________
dense (Dense)                (None, 400)               400400    
_________________________________________________________________
batch_normalization (BatchNo (None, 400)               1600      
_________________________________________________________________
re_lu (ReLU)                 (None, 400)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 50)                20050     
=================================================================
Total params: 20,665,034
Trainable params: 4,490,090
Non-trainable params: 16,174,944

问题是我如何才能真正使 DenseNet 的前 161 层不可训练,而后 40 层可在加载的模型上训练?

【问题讨论】:

    标签: python machine-learning keras conv-neural-network keras-layer


    【解决方案1】:

    densenet201 (Functional) 是一个嵌套模型,因此您可以像访问“最顶层”模型的层一样访问其层。

    saved_model.layers[0].layers
    

    其中saved_model.layers[0] 是一个具有自己层的模型。

    在你的循环中,你需要像这样访问层

    saved_model.layers[0].layers[-cnt].trainable = True
    

    更新

    默认情况下,加载模型的层是可训练的 (trainable=True),因此您需要将底层的 trainable 属性设置为 False

    【讨论】:

    • 加载保存的模型后,layer.trainable是否默认为False
    • 不,它默认为True。您可以查看-print(saved_model.layers[0].layers[2].trainable)
    • 哦,刚刚意识到这一点。我在考虑如何访问这些层,而不是逻辑本身。
    猜你喜欢
    • 1970-01-01
    • 2019-01-20
    • 1970-01-01
    • 1970-01-01
    • 2017-08-19
    • 1970-01-01
    • 2020-05-12
    • 2018-07-31
    • 2018-04-15
    相关资源
    最近更新 更多