【问题标题】:Tensorflow subclass keras with mulitple output shape具有多个输出形状的 TensorFlow 子类 keras
【发布时间】:2021-04-03 04:12:06
【问题描述】:
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        self.build(input_shape=[None, 1])

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

MyModel().summary()

模型图也不行:

tf.keras.utils.plot_model(model, to_file='model_1.png', show_shapes=True)

我在几个 tensorflow 版本 2.3.0、2.3.1 和 2.4.1 上尝试了这段代码,每次output shape 都是multiple!这是一个错误吗?有什么办法吗?

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    不是错误。一般来说,我们不能假设任何关于子类模型的结构。这就是为什么您无法在与 FunctionalSequential API 类似的模型 Subclasses API 中的 .summary() 中获得输出形状。

    但这里有一个解决方法来实现这一点。您可以通过以下方法实现。

    import tensorflow as tf 
    
    class MyModel(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(1)
            self.build(input_shape=[None, 1])
    
        def call(self, inputs, **kwargs):
            return self.dense(inputs)
    
        def build_graph(self):
            x = tf.keras.layers.Input(shape=(1))
            return tf.keras.Model(inputs=[x], outputs=self.call(x))
    
    MyModel().build_graph().summary()
    
    Model: "model"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_2 (InputLayer)         [(None, 1)]               0         
    _________________________________________________________________
    dense_3 (Dense)              (None, 1)                 2         
    =================================================================
    Total params: 2
    Trainable params: 2
    Non-trainable params: 0
    _________________________________________________________________
    

    与绘制模型相同。

    tf.keras.utils.plot_model(
        MyModel().build_graph()                     
    )
    

    【讨论】:

    • 我注意到即使在 tensorboard 中,图表也没有显示所有层,特别是 BatchNormalization 层。
    • 这与此无关,而是另一个问题。导入 bn 后试试这个BatchNormalization._USE_V2_BEHAVIOR = False
    • AFAIK,它发生在tf 2.0, 2.1,但后来在2.3 中修复>=。
    • 我通过安装tf 2.4.1修复了它
    猜你喜欢
    • 2023-02-09
    • 1970-01-01
    • 1970-01-01
    • 2018-08-23
    • 1970-01-01
    • 2018-03-09
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多