【发布时间】:2021-01-31 04:55:17
【问题描述】:
我有以下代码用于将 DNN 模型分成两部分。
def split(model, input):
starting_layer_name = input
new_input = layers.Input( batch_shape=model.get_layer(starting_layer_name).get_input_shape_at(0))
layer_outputs = {}
def get_output_of_layer(layer):
if layer.name in layer_outputs:
return layer_outputs[layer.name]
if layer.name == starting_layer_name:
out = layer( new_input )
layer_outputs[layer.name] = out
return out
prev_layers = []
for node in layer._inbound_nodes:
prev_layers.extend( node.inbound_layers )
pl_outs = []
for pl in prev_layers:
pl_outs.extend( [get_output_of_layer( pl )] )
out = layer( pl_outs[0] if len( pl_outs ) == 1 else pl_outs )
layer_outputs[layer.name] = out
return out
if starting_layer_name=='input_1':
new_output = get_output_of_layer(model.layers[-21])
block_1 = models.Model( new_input, new_output )
return block_1
elif starting_layer_name=='block1_pool':
new_output =get_output_of_layer((model.layers[-1]))
block_2 = models.Model(new_input, new_output)
return block_2
block_1=split(model,'input_1')
block_2=split(model,'block1_pool')
block_1.save('my_model1.h5')
block_2.save('my_model2.h5')
当我尝试运行以下内容时,我检索到“图表已断开连接无法获取张量的值。
from Keras.models import load_model
model = load_model('my_model1.h5')
model.summary()
高度赞赏帮助解决此问题。我试图拆分模型的当前方法给了我一个错误,是否有另一种方法可以在 keras 中解决这个问题。
【问题讨论】:
标签: python tensorflow keras tensorflow2.0 tf.keras