【Question Title】: How to freeze a sub-model for one model, without affecting other models?
【Posted】: 2020-04-17 23:59:07
【Question】:

I am trying to build a GAN-style model, but I cannot figure out how to correctly set `trainable = False` for just one model. It seems that every model that uses the sub-model is affected.

Code:

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

print(tf.__version__)

def build_submodel():
  inp = tf.keras.Input(shape=(3,))
  x = Dense(5)(inp)
  model = Model(inputs=inp, outputs=x)
  return model

def build_model_A():
  inp = tf.keras.Input(shape=(3,))
  x = submodel(inp)
  x = Dense(7)(x)
  model = Model(inputs=inp, outputs=x)
  return model

def build_model_B():
  inp = tf.keras.Input(shape=(11,))
  x = Dense(3)(inp)
  x = submodel(x)
  model = Model(inputs=inp, outputs=x)
  return model

submodel = build_submodel()
model_A = build_model_A()
model_A.compile("adam", "mse")
model_A.summary()
submodel.trainable = False
# same result when freezing the individual layers instead:
# for layer in submodel.layers:
#   layer.trainable = False
model_B = build_model_B()
model_B.compile("adam", "mse")
model_B.summary()

model_A.summary()

Output:

Model: "model_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 3)]               0         
_________________________________________________________________
model_9 (Model)              (None, 5)                 20        
_________________________________________________________________
dense_10 (Dense)             (None, 7)                 42        
=================================================================
Total params: 62
Trainable params: 62
Non-trainable params: 0
_________________________________________________________________
Model: "model_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_12 (InputLayer)        [(None, 11)]              0         
_________________________________________________________________
dense_11 (Dense)             (None, 3)                 36        
_________________________________________________________________
model_9 (Model)              (None, 5)                 20        
=================================================================
Total params: 56
Trainable params: 36
Non-trainable params: 20
_________________________________________________________________
Model: "model_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 3)]               0         
_________________________________________________________________
model_9 (Model)              (None, 5)                 20        
_________________________________________________________________
dense_10 (Dense)             (None, 7)                 42        
=================================================================
Total params: 62
Trainable params: 42
Non-trainable params: 20
_________________________________________________________________

At first, model_A had no non-trainable weights. But after model_B was built, model_A suddenly had some non-trainable weights.

Also, the summary does not show which layers are non-trainable, only the total non-trainable parameter count. Is there a better way to check which layers of a model are frozen?
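(Editorial note on what appears to be happening: both outer models wrap the very same submodel object, so its `trainable` attribute is a single shared flag that both summaries read. A minimal sketch, with illustrative variable names, assuming TF 2.x:)

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense

# One shared submodel, used inside two outer models.
sub = tf.keras.Sequential([Dense(5, input_shape=(3,))])

a_in = tf.keras.Input(shape=(3,))
model_a = tf.keras.Model(a_in, Dense(7)(sub(a_in)))

b_in = tf.keras.Input(shape=(11,))
model_b = tf.keras.Model(b_in, sub(Dense(3)(b_in)))

# The nested layer inside each outer model is literally the same object,
# so `sub.trainable` is one flag read by both summaries.
print(model_a.layers[1] is model_b.layers[2])  # True

sub.trainable = False  # now BOTH summaries report 20 non-trainable params
```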

【Discussion】:

    Tags: tensorflow keras


    【Solution 1】:

    You can use this function to show which layers are trainable and which are not:

    from tensorflow.keras import backend as K

    def print_params(model):

        def count_params(weights):
            """Count the total number of scalars composing the weights.
            # Arguments
                weights: An iterable containing the weights on which to compute params
            # Returns
                The total number of scalars composing the weights
            """
            weight_ids = set()
            total = 0
            for w in weights:
                if id(w) not in weight_ids:
                    weight_ids.add(id(w))
                    total += int(K.count_params(w))
            return total

        trainable_count = count_params(model.trainable_weights)
        non_trainable_count = count_params(model.non_trainable_weights)

        print('id\ttrainable : layer name')
        print('-------------------------------')
        for i, layer in enumerate(model.layers):
            print(i, '\t', layer.trainable, '\t  :', layer.name)
        print('-------------------------------')

        print('Total params: {:,}'.format(trainable_count + non_trainable_count))
        print('Trainable params: {:,}'.format(trainable_count))
        print('Non-trainable params: {:,}'.format(non_trainable_count))
    

    It will print output like this:

    id  trainable : layer name
    -------------------------------
    0    False    : input_1
    1    False    : block1_conv1
    2    False    : block1_conv2
    3    False    : block1_pool
    4    False    : block2_conv1
    5    False    : block2_conv2
    6    False    : block2_pool
    7    False    : block3_conv1
    8    False    : block3_conv2
    9    False    : block3_conv3
    10   False    : block3_pool
    11   False    : block4_conv1
    12   False    : block4_conv2
    13   False    : block4_conv3
    14   False    : block4_pool
    15   False    : block5_conv1
    16   False    : block5_conv2
    17   False    : block5_conv3
    18   False    : block5_pool
    19   True     : global_average_pooling2d
    20   True     : dense
    21   True     : dense_1
    22   True     : dense_2
    -------------------------------
    Total params: 15,245,130
    Trainable params: 530,442
    Non-trainable params: 14,714,688
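    (A lighter alternative, not from the answer above: you can iterate `model.layers` directly, and recent Keras versions can also annotate the summary itself. A sketch with illustrative names:)

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense

sub = tf.keras.Sequential([Dense(5, input_shape=(3,))], name="sub")
sub.trainable = False
model = tf.keras.Sequential([Dense(3, input_shape=(11,), name="pre"), sub])

# One-liner check: walk the layer list and report each trainable flag.
for i, layer in enumerate(model.layers):
    print(i, layer.trainable, ":", layer.name)

# Recent Keras versions can show the flag per layer in the summary itself:
# model.summary(show_trainable=True)
```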
    

    【Comments】:

    • Your code prints `layer.trainable`. According to this, this, this, and this, changing `trainable` after `model.compile` has no effect. But my code above contradicts that assumption, which is why I asked this question.
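    (Editorial note: the Keras FAQ states that for a change of `trainable` to take effect on training, you need to call `compile()` again on the model; the summary, by contrast, always reflects the current value of the flag. The usual GAN-style pattern is therefore to toggle the flag and recompile each model that should see the change. A sketch under that assumption, with illustrative names:)

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense

# One shared submodel used by both outer models.
sub = tf.keras.Sequential([Dense(5, input_shape=(3,))])

a_in = tf.keras.Input(shape=(3,))
model_a = tf.keras.Model(a_in, Dense(7)(sub(a_in)))

b_in = tf.keras.Input(shape=(11,))
model_b = tf.keras.Model(b_in, sub(Dense(3)(b_in)))

# Compile model_a while the shared submodel is trainable ...
sub.trainable = True
model_a.compile("adam", "mse")

# ... then freeze it before compiling model_b, so only model_b's
# compiled training step treats the submodel as frozen.
sub.trainable = False
model_b.compile("adam", "mse")

# Per the Keras FAQ, flipping `sub.trainable` again affects neither
# compiled training loop until the corresponding model is recompiled.
```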