# 我通过把两个图中名称相同(或相似)的节点相互映射来解决这个问题:先用 tf.import_graph_def 的 input_map 将两个图连接起来,再用 graph_transforms 的 strip_unused_nodes 删除不再使用的节点。为了保持可量化性,不要使用 merge_duplicate_nodes 或 fold_batch_norms 变换,否则会因缺少 min/max 量化范围而产生量化误差。
import tensorflow as tf
import numpy as np
def load_graph(pb_file):
    """Load a serialized GraphDef from ``pb_file`` into a brand-new graph.

    Nodes are imported with an empty name prefix (``name=''``), so inside the
    returned graph they keep exactly the names stored in the file.

    Args:
        pb_file: path to a binary (serialized) ``GraphDef`` file, e.g. a
            frozen model.

    Returns:
        A new ``tf.Graph`` containing the imported nodes.
    """
    with tf.gfile.GFile(pb_file, 'rb') as pb_stream:
        raw_bytes = pb_stream.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(raw_bytes)

    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(graph_def, name='')
    return loaded
from tensorflow.tools.graph_transforms import TransformGraph

# Frozen graphs to merge. Ops of `trained` (graph2) whose names match ops of
# `resnet_pretrained` (graph1) are rerouted onto graph1's tensors, so the
# duplicated graph2 ops become unused and can be stripped afterwards.
resnet_pretrained = 'frozen_124.pb'
trained = 'frozen.pb'
# NOTE(review): not used below; the merged graph is written to
# `optimized_graph`. Kept so any existing reference to the name still works.
final_graph = 'final_graph.pb'
# File the optimized, merged graph is written to (in the current directory).
optimized_graph = 'optimized_model.pb'

# Import-scope prefixes of the two graphs inside the merged graph. These are
# the prefixes tf.import_graph_def picks by default for two consecutive
# imports; they are passed explicitly so the node names used by
# TransformGraph below are guaranteed to agree with them.
graph1_scope = 'import'
graph2_scope = 'import_1'
# Unprefixed names of the merged graph's inputs and output.
graph1_input = 'input_image'
graph2_input = 'input_box'
graph2_output = 'concatenate_1/concat'
# A graph2 op containing `renamed_scope` also matches the graph1 op whose name
# has `renamed_scope` replaced by `shared_scope`.
renamed_scope = 'resnet'
shared_scope = 'model'


def _shared_tensor_names(src_graph, dst_graph):
    """Map tensor names of ``dst_graph`` to the matching tensors of ``src_graph``.

    An op of ``dst_graph`` matches an op of ``src_graph`` when both have the
    same name. Failing that, it matches when replacing every occurrence of
    ``renamed_scope`` with ``shared_scope`` in its name yields the name of an
    op in ``src_graph``.

    For each matching pair, every output index present on both ops is mapped
    (``"dst_op:i" -> "src_op:i"``), but only when the two tensors have the same
    dtype.

    Ops with no outputs are skipped, because ``"name:0"`` would not exist and
    ``tf.import_graph_def`` would reject it. Mismatched dtypes are skipped
    because feeding such a tensor through ``input_map`` fails at import time.

    Returns:
        dict mapping ``dst_graph`` tensor names to ``src_graph`` tensor names.
    """
    src_ops = {op.name: op for op in src_graph.get_operations()}
    mapping = {}
    for op in dst_graph.get_operations():
        counterpart = src_ops.get(op.name)
        if counterpart is None and renamed_scope in op.name:
            counterpart = src_ops.get(op.name.replace(renamed_scope, shared_scope))
        if counterpart is None:
            continue
        # zip() stops at the shorter output list, so only indices that exist
        # on both ops are mapped; an op without outputs maps nothing.
        for dst_tensor, src_tensor in zip(op.outputs, counterpart.outputs):
            if dst_tensor.dtype == src_tensor.dtype:
                mapping[dst_tensor.name] = src_tensor.name
    return mapping


# loads both graphs
graph1 = load_graph(resnet_pretrained)
graph2 = load_graph(trained)

# graph2 tensor name -> graph1 tensor name that should replace it.
replace_dict = _shared_tensor_names(graph1, graph2)

with tf.Graph().as_default() as final:
    # Import graph1 and fetch the tensors graph2 will reuse. A fixed key order
    # keeps the returned tensors aligned with their graph2 names.
    dst_names = list(replace_dict)
    src_tensors = tf.import_graph_def(
        graph1.as_graph_def(),
        return_elements=[replace_dict[name] for name in dst_names],
        name=graph1_scope)
    input_map = dict(zip(dst_names, src_tensors))
    # Import graph2 with its matching tensors rerouted onto graph1's. The
    # return element only checks that the expected output node exists.
    tf.import_graph_def(
        graph2.as_graph_def(),
        input_map=input_map,
        return_elements=[graph2_output + ':0'],
        name=graph2_scope)

# Drop Identity nodes and every node not needed to compute the output from
# the two inputs. This removes graph2's now-unused duplicates.
# merge_duplicate_nodes / fold_batch_norms are deliberately not used, to
# keep the graph quantizable.
transforms = ['remove_nodes(op=Identity)',
              'strip_unused_nodes']
output_graph_def = TransformGraph(
    final.as_graph_def(),
    [graph1_scope + '/' + graph1_input, graph2_scope + '/' + graph2_input],  # inputs
    [graph2_scope + '/' + graph2_output],  # outputs
    transforms)
tf.train.write_graph(output_graph_def, '.', as_text=False, name=optimized_graph)
print('Graph optimized!')