【Question Title】: Keras.model.summary does not correctly display my model..?
【Posted】: 2020-05-05 01:55:15
【Question】:

I want to see a summary of my model with keras.model.summary, but it does not work as expected. My code is as follows:

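from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self, x):
        # This layer is created inside the method on every call instead of
        # being assigned to an attribute in __init__, so it is not tracked
        # as one of the model's sublayers.
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.build((None, 32, 32, 3))
model.summary()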

I expected to see a BatchNorm layer, but the summary is as follows:

Model: "my_model_30"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_31 (Conv2D)           multiple                  896       
_________________________________________________________________
flatten_30 (Flatten)         multiple                  0         
_________________________________________________________________
dense_60 (Dense)             multiple                  3686528   
_________________________________________________________________
dense_61 (Dense)             multiple                  1290      
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0

The BatchNorm layer created in the 'trythis' method is missing from the summary.

How can I fix this?

Thanks for reading.

【Question Discussion】:

    Tags: python tensorflow keras


    【Solution 1】:

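    Shape inference in a subclassed model is not automatic the way it is in the functional API. So I added a 'model' method to the subclassed model that calls it on a tf.keras.layers.Input and wraps the result in a functional Model, as shown below. Note that there are several ways to do this; this is just one of them. See more details in a similar question I answered here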

    import tensorflow as tf
    
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization
    
    class MyModel(Model):
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = Conv2D(32, 3, activation='relu')
            self.flatten = Flatten()
            self.d1 = Dense(128, activation='relu')
            self.d2 = Dense(10, activation='relu')
    
        def trythis(self, x):
            a = BatchNormalization()
            b = a(x)
            return b
    
        def call(self, x):
            x = self.conv1(x)
            x = self.trythis(x)
            x = self.flatten(x)
            x = self.d1(x)
            return self.d2(x)
    
        def model(self):
            # Call the subclassed model on a symbolic Input and wrap the result
            # in a functional Model, so every layer applied in call() is traced.
            x = tf.keras.layers.Input(shape=(32, 32, 3))
            return Model(inputs=[x], outputs=self.call(x))
    
    model = MyModel()
    model_functional = model.model()
    # model.build((None, 32, 32, 3)) is not needed here
    model_functional.summary()
    

    The summary is as follows:

    Model: "model"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_5 (InputLayer)         [(None, 32, 32, 3)]       0         
    _________________________________________________________________
    conv2d_5 (Conv2D)            (None, 30, 30, 32)        896       
    _________________________________________________________________
    batch_normalization (BatchNo (None, 30, 30, 32)        128       
    _________________________________________________________________
    flatten_4 (Flatten)          (None, 28800)             0         
    _________________________________________________________________
    dense_8 (Dense)              (None, 128)               3686528   
    _________________________________________________________________
    dense_9 (Dense)              (None, 10)                1290      
    =================================================================
    Total params: 3,688,842
    Trainable params: 3,688,778
    Non-trainable params: 64
    _________________________________________________________________
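The parameter counts in both summaries can be checked by hand. The sketch below is plain Python; the helper names conv2d_params, batchnorm_params and dense_params are hypothetical, and it assumes the Conv2D defaults used above (stride 1, 'valid' padding) and BatchNormalization over the last axis with 32 channels. It shows where the 64 non-trainable parameters come from (the moving mean and moving variance) and that the question's total of 3,688,714 is exactly the functional model's total minus the BatchNormalization layer's 128 parameters.

```python
def conv2d_params(kernel_size, in_channels, filters):
    # One kernel_size x kernel_size x in_channels kernel per filter, plus one bias per filter.
    return kernel_size * kernel_size * in_channels * filters + filters


def batchnorm_params(channels):
    # gamma and beta are trainable; moving_mean and moving_variance are not.
    return 2 * channels, 2 * channels


def dense_params(in_features, units):
    # in_features x units weight matrix, plus one bias per unit.
    return in_features * units + units


conv_side = 32 - 3 + 1  # 'valid' padding, stride 1: 32x32 input -> 30x30
conv = conv2d_params(3, 3, 32)
bn_trainable, bn_non_trainable = batchnorm_params(32)
flat = conv_side * conv_side * 32
d1 = dense_params(flat, 128)
d2 = dense_params(128, 10)

total = conv + bn_trainable + bn_non_trainable + d1 + d2
trainable = total - bn_non_trainable

print(conv, bn_trainable + bn_non_trainable, flat, d1, d2)  # 896 128 28800 3686528 1290
print(total, trainable, bn_non_trainable)  # 3688842 3688778 64
print(conv + d1 + d2)  # 3688714, the question's total (no BatchNormalization)
```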
    

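    【Discussion】:

    • Aha.. thank you very much! I have another question. In my original code, even though keras.model.summary does not correctly capture the model's architecture, does the model still work as I intended?
    • That is expected. The only issue is that shape inference is not easy for subclassed models; it comes naturally only for static graphs, i.e. functional and Sequential models. This inference is needed only for visualization, such as model.summary or plotting the model; otherwise the subclassed model works as expected.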
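A commonly used alternative, not part of the original answer, is to create the BatchNormalization layer once in __init__ and reuse it in trythis. The sketch below assumes TensorFlow 2.x with tf.keras; the attribute name self.bn is hypothetical, and it runs one dummy batch through the model instead of calling build. Because the layer is stored as an attribute, Keras tracks it as a sublayer, so it should appear in model.summary() and the same gamma, beta and moving statistics are reused on every call rather than a new layer being built each time trythis runs.

```python
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization


class MyModel(Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, 3, activation="relu")
        # Created once and stored as an attribute, so Keras tracks it as a sublayer.
        self.bn = BatchNormalization()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation="relu")
        self.d2 = Dense(10, activation="relu")

    def trythis(self, x):
        # Reuse the tracked layer instead of constructing a new one on every call.
        return self.bn(x)

    def call(self, x):
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


model = MyModel()
# Run one dummy batch through the model so every layer creates its weights.
model(tf.zeros((1, 32, 32, 3)))
model.summary()
```

Output shapes in this summary may still be reported as "multiple", since no functional graph is traced; the 'model' wrapper from the answer remains the way to see concrete shapes.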