【问题标题】:How do I plot a Keras/Tensorflow subclassing API model?如何绘制 Keras/Tensorflow 子类化 API 模型?
【发布时间】:2020-08-09 03:51:21
【问题描述】:

我使用 Keras 子类化 API 制作了一个可以正确运行的模型。 model.summary() 也可以正常工作。当尝试使用tf.keras.utils.plot_model() 来可视化我的模型架构时,它只会输出这个图像:

这几乎感觉像是来自 Keras 开发团队的一个笑话。这是完整的架构:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x


skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet,
         show_shapes=True, to_file='/home/nicolas/Desktop/model.png')

【问题讨论】:

    标签: python tensorflow plot keras deep-learning


    【解决方案1】:

    更新(2021 年 1 月 4 日):这似乎是可能的;见@M.Innat 的answer


    无法做到这一点,因为与使用功能/顺序 API(在 TF 术语中称为图网络)创建的模型相比,在 TensorFlow 中实现的模型子分类基本上在特性和功能上受到限制。如果您查看plot_model 源代码,您会在model_to_dot 函数(由plot_model 调用)中看到the following check

    if not model._is_graph_network:
      node = pydot.Node(str(id(model)), label=model.name)
      dot.add_node(node)
      return dot
    

    正如我所提到的,子分类模型不是图形网络,因此只会为这些模型绘制包含模型名称的节点(即您观察到的相同事物)。

    这已经在Github issue 中讨论过了,TensorFlow 的开发者之一通过给出以下论点证实了这种行为:

    @omalleyt12 评论:

    是的,一般来说,我们不能假设任何关于子类模型的结构。如果您的模型可以看作是层块并且您希望将其可视化,我们建议您查看功能 API

    【讨论】:

      【解决方案2】:

      我找到了一些使用模型子类 API 进行绘图的解决方法。很明显,Sub-Classing API 不支持 Sequential 或 Functional API,如 model.summary() 和使用 plot_model 的漂亮可视化。在这里,我将演示两者。

      class my_model(Model):
          def __init__(self, dim):
              super(my_model, self).__init__()
              self.Base  = VGG16(input_shape=(dim), include_top = False, weights = 'imagenet')
              self.GAP   = L.GlobalAveragePooling2D()
              self.BAT   = L.BatchNormalization()
              self.DROP  = L.Dropout(rate=0.1)
              self.DENS  = L.Dense(256, activation='relu', name = 'dense_A')
              self.OUT   = L.Dense(1, activation='sigmoid')
          
          def call(self, inputs):
              x  = self.Base(inputs)
              g  = self.GAP(x)
              b  = self.BAT(g)
              d  = self.DROP(b)
              d  = self.DENS(d)
              return self.OUT(d)
          
          # AFAIK: The most convenient method to print model.summary() 
          # similar to the sequential or functional API like.
          def build_graph(self):
              x = Input(shape=(dim))
              return Model(inputs=[x], outputs=self.call(x))
      
      dim = (124,124,3)
      model = my_model((dim))
      model.build((None, *dim))
      model.build_graph().summary()
      

      它将产生如下:

      Layer (type)                 Output Shape              Param #   
      =================================================================
      input_67 (InputLayer)        [(None, 124, 124, 3)]     0         
      _________________________________________________________________
      vgg16 (Functional)           (None, 3, 3, 512)         14714688  
      _________________________________________________________________
      global_average_pooling2d_32  (None, 512)               0         
      _________________________________________________________________
      batch_normalization_7 (Batch (None, 512)               2048      
      _________________________________________________________________
      dropout_5 (Dropout)          (None, 512)               0         
      _________________________________________________________________
      dense_A (Dense)              (None, 256)               402192    
      _________________________________________________________________
      dense_7 (Dense)              (None, 1)                 785       
      =================================================================
      Total params: 14,848,321
      Trainable params: 14,847,297
      Non-trainable params: 1,024
      

      现在通过使用build_graph 函数,我们可以简单地绘制整个架构。

      # Just showing all possible argument for newcomer.  
      tf.keras.utils.plot_model(
          model.build_graph(),                      # here is the trick (for now)
          to_file='model.png', dpi=96,              # saving  
          show_shapes=True, show_layer_names=True,  # show shapes and layer name
          expand_nested=False                       # will show nested block
      )
      

      它将产生如下:-)

      【讨论】:

      • 我喜欢这个解决方法。但是,它仅适用于简单模型。一旦一个模型被另一个模型包围,嵌套将不会被解析(即生成器和判别器实现为keras.Model 的 GAN,参见here)。设置 expand_nested=True 不会改变行为。有什么建议吗?
      • 不确定。但如果可能,请分享一些玩具代码以供探索。
      • 我用玩具例子问了一个新问题here
      【解决方案3】:

      另一种解决方法:使用tf2onnx将savemodel格式模型转换为onnx,然后使用netron查看模型架构。

      这是 netron 中模型的一部分:

      【讨论】:

        猜你喜欢
        • 2019-08-30
        • 2019-12-16
        • 1970-01-01
        • 2021-03-11
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多