【发布时间】:2017-08-24 08:47:23
【问题描述】:
我在 tensorflow 中构建了一个目标随机森林模型,并希望为 android 冻结和优化它。 我使用以下函数来构建 tesnor_forest 估计器:
def build_estimator(_model_dir, _num_classes, _num_features, _num_trees, _max_nodes):
params = tensor_forest.ForestHParams(
num_classes=_num_classes, num_features=_num_features,
num_trees=_num_trees, max_nodes=_max_nodes, min_split_samples=3)
graph_builder_class = tensor_forest.RandomForestGraphs
return random_forest.TensorForestEstimator(
params, graph_builder_class=graph_builder_class,
model_dir=_model_dir)
该函数将文本模型存储到指定模型目录下的graph.pbtxt文件中。
然后我使用以下方法训练它:
est = build_estimator(output_model_dir, 3,np.size(features_eval,1), 5,6)
train_X = features_eval.astype(dtype=np.float32)
train_Y = labels_y.astype(dtype=np.float32)
est.fit(x=train_X, y=train_Y, batch_size=np.size(features_eval,0))
(在这个简单的示例中:树数 = 5,max_nodes=6)
现在我想冻结模型,所以我调用了这个函数:
def save_model_android(model_path):
checkpoint_state_name = "model.ckpt-1"
input_graph_name = "graph.pbtxt"
output_graph_name = "freezed_model.pb"
checkpoint_path = os.path.join(model_path, checkpoint_state_name)
input_graph_path = os.path.join(model_path, input_graph_name)
input_saver_def_path = None
input_binary = False
output_node_names = "output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(model_path, output_graph_name)
clear_devices = True
freeze_graph(input_graph_path, input_saver_def_path,
input_binary, checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices, "")
在生成的 freezed_model.pb 文件中,我只得到 1 个操作,即输出节点。 在控制台中,当调用 freeze_graph 函数时,我收到以下消息:
Converted 0 variables to const ops.
1 ops in the final graph.
有谁知道为什么调用 freeze_graph 时只导出一个节点?
我正在使用支持 cuda 的 Tensorflow 1.2.1 版,从 linux 上的源代码安装
【问题讨论】:
标签: android tensorflow deep-learning tensor