【问题标题】:Check which are the next layers in a tensorflow keras model检查张量流 keras 模型中的下一层
【发布时间】:2021-09-08 14:38:49
【问题描述】:

我有一个 模型,它在层之间有快捷方式。对于每一层,我想获取下一个连接层的名称(或索引),因为简单地遍历所有 model.layers 不会告诉我该层是否连接到前一层。

一个示例模型可以是:

model = tf.keras.applications.resnet50.ResNet50(
include_top=True, weights='imagenet', input_tensor=None,
input_shape=None, pooling=None, classes=1000)

【问题讨论】:

    标签: keras tensorflow keras tensorflow2.0


    【解决方案1】:

    可以通过这种方式提取dict格式的信息...

    首先,定义一个效用函数,并从每个Functional 模型(code reference)中获取model.summary() 方法中制作的相关节点(code reference

    relevant_nodes = []
    for v in model._nodes_by_depth.values():
        relevant_nodes += v
    
    def get_layer_summary_with_connections(layer):
        
        info = {}
        connections = []
        for node in layer._inbound_nodes:
            if relevant_nodes and node not in relevant_nodes:
                # node is not part of the current network
                continue
    
            for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
                connections.append(inbound_layer.name)
                
        name = layer.name
        info['type'] = layer.__class__.__name__
        info['parents'] = connections
                
        return info
    

    其次,通过层迭代提取信息:

    results = {}
    layers = model.layers
    for layer in layers:
        info = get_layer_summary_with_connections(layer)
        results[layer.name] = info
    

    results 是一个嵌套的dict,格式如下:

    {
      'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'},
      ...
      'layer_name': {'type':'the layer type', 'parents':'list of the parent layers'}
    }
    

    对于ResNet50,结果为:

    {
      'input_4': {'type': 'InputLayer', 'parents': []},
      'conv1_pad': {'type': 'ZeroPadding2D', 'parents': ['input_4']},
      'conv1_conv': {'type': 'Conv2D', 'parents': ['conv1_pad']},
      'conv1_bn': {'type': 'BatchNormalization', 'parents': ['conv1_conv']},
      ...
      'conv5_block3_out': {'type': 'Activation', 'parents': ['conv5_block3_add']},
      'avg_pool': {'type': 'GlobalAveragePooling2D', 'parents' ['conv5_block3_out']},
      'predictions': {'type': 'Dense', 'parents': ['avg_pool']}
    }
    

    另外,您可以修改get_layer_summary_with_connections,返回您感兴趣的所有信息

    【讨论】:

      【解决方案2】:

      您可以查看整个模型及其与keras的Model plotting utilities的连接

      tf.keras.utils.plot_model(model, to_file='path/to/image', show_shapes=True)
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2021-01-29
        • 1970-01-01
        • 2019-05-07
        • 2018-10-11
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多