I. Converting TensorFlow model file formats
pb_convert.py is as follows:

import tensorflow as tf
from google.protobuf import text_format

def convert_pb_to_pbtxt(filename):
    """Read a binary GraphDef (.pb) and write it as text to ./protobuf.pbtxt."""
    with tf.gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Serialize the parsed GraphDef in human-readable protobuf text format
    tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)

def convert_pbtxt_to_pb(filename):
    """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.

    Also writes the graph in binary form to ./protobuf.pb.

    Args:
      filename: The name of a file containing a GraphDef pbtxt
        (text-formatted `tf.GraphDef` protocol buffer data).
    """
    with tf.gfile.GFile(filename, 'r') as f:
        graph_def = tf.GraphDef()
        # Merge the human-readable string in the file into `graph_def`
        text_format.Merge(f.read(), graph_def)
    tf.train.write_graph(graph_def, './', 'protobuf.pb', as_text=False)
    return graph_def

Usage:

import pb_convert
pb_convert.convert_pb_to_pbtxt('classify_image_graph_def.pb')

II. Viewing the structure of a .pb model file in TensorBoard

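import tensorflow as tf

model = 'model.pb'  # replace with the path to your own .pb file
graph = tf.get_default_graph()
graph_def = tf.GraphDef()
with tf.gfile.GFile(model, 'rb') as f:
    graph_def.ParseFromString(f.read())
# Import the nodes into the default graph under the name prefix 'graph'
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)
summaryWriter.close()  # flush the event file to disk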

III. Model persistence
1. Saving the model
2. Predicting with the saved model

IV. Training and prediction with the HelloWorld (MNIST) example
The training and prediction code is as follows:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

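import tensorflow as tf
sess = tf.InteractiveSession()

# Input: flattened 28x28 images (784 pixels)
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x_input')

# Softmax regression: y = softmax(xW + b)
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

# One-hot labels and cross-entropy loss. tf.log(y) can yield NaN if y underflows
# to 0; tf.nn.softmax_cross_entropy_with_logits is the numerically stabler option.
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=[1]))

# Plain gradient descent with learning rate 0.5
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)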

#########################################
# training_flag = 1: train, then save a checkpoint to ./cmodel.ckpt
# training_flag = 0: restore that checkpoint (requires an earlier training run)
saver = tf.train.Saver()
training_flag = 0
if training_flag == 1:
    print('training')
    tf.global_variables_initializer().run()
    for i in range(1000):
        # Mini-batch of 100 images with their one-hot labels
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train_step.run({x: batch_xs, y_: batch_ys})
    saver.save(sess, "./cmodel.ckpt")
else:
    print('prediction')
    saver.restore(sess, "./cmodel.ckpt")
#########################################

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
with tf.name_scope('accuracy'):
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

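graph = tf.get_default_graph()
# Write the graph for TensorBoard (log/) and as a binary GraphDef (test.pb).
# test.pb holds only the graph structure, not the trained variable values.
summaryWriter = tf.summary.FileWriter('log/', graph)
summaryWriter.close()
tf.train.write_graph(graph, './', 'test.pb', as_text=False)
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))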
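To make the math behind the TensorFlow graph above concrete, here is a dependency-free sketch of the same softmax regression in plain Python: zero-initialized W and b, mean cross-entropy loss, mini-batch gradient descent with learning rate 0.5, and argmax accuracy. It is not how TensorFlow implements it. The helper names (softmax, forward, sgd_step, accuracy) and the two-feature, two-class toy data are hypothetical stand-ins for MNIST's 784 pixels and 10 digits.

```python
import math
import random

# Hypothetical helpers (not TensorFlow APIs) mirroring the graph above.


def softmax(z):
    # Subtract the max before exponentiating so exp() cannot overflow
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def forward(x, W, b):
    # z_j = sum_i x_i * W[i][j] + b_j, then y = softmax(z), i.e. softmax(xW + b)
    z = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
    return softmax(z)


def cross_entropy(xs, ys, W, b):
    # Batch mean of -sum_j y'_j * log(y_j), like the cross_entropy tensor
    total = 0.0
    for x, t in zip(xs, ys):
        p = forward(x, W, b)
        total -= sum(tj * math.log(pj) for tj, pj in zip(t, p))
    return total / len(xs)


def sgd_step(xs, ys, W, b, lr):
    # For softmax + cross-entropy the gradient w.r.t. z is (y - y'), so
    # dW = x^T (y - y') / N and db = sum(y - y') / N; W and b are updated in place.
    gW = [[0.0] * len(b) for _ in W]
    gb = [0.0] * len(b)
    for x, t in zip(xs, ys):
        p = forward(x, W, b)
        for j in range(len(b)):
            d = p[j] - t[j]
            gb[j] += d
            for i in range(len(x)):
                gW[i][j] += x[i] * d
    n = len(xs)
    for j in range(len(b)):
        b[j] -= lr * gb[j] / n
        for i in range(len(W)):
            W[i][j] -= lr * gW[i][j] / n


def argmax(v):
    return max(range(len(v)), key=lambda i: v[i])


def accuracy(xs, ys, W, b):
    # Fraction of samples whose predicted class equals the labelled class
    hits = sum(argmax(forward(x, W, b)) == argmax(t) for x, t in zip(xs, ys))
    return hits / len(xs)


# Hypothetical toy data: class 0 near (1, 0), class 1 near (0, 1)
random.seed(0)
xs, ys = [], []
for _ in range(200):
    c = random.randint(0, 1)
    xs.append([1.0 - c + random.gauss(0, 0.1), c + random.gauss(0, 0.1)])
    ys.append([1.0, 0.0] if c == 0 else [0.0, 1.0])

W = [[0.0, 0.0], [0.0, 0.0]]  # zero init, like tf.zeros([784, 10])
b = [0.0, 0.0]
loss_before = cross_entropy(xs, ys, W, b)
for step in range(100):
    batch = random.sample(range(len(xs)), 20)  # random mini-batch, cf. next_batch(100)
    sgd_step([xs[i] for i in batch], [ys[i] for i in batch], W, b, lr=0.5)
loss_after = cross_entropy(xs, ys, W, b)
acc = accuracy(xs, ys, W, b)
print(loss_before, loss_after, acc)
```

With all-zero weights every class gets probability 1/2, so the initial loss is exactly log 2; training drives it down, just as the TensorFlow version starts from tf.zeros and lowers cross_entropy.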

V. Learning from the MNIST HelloWorld
1. Training the model with TensorFlow
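Written out as equations, the softmax regression trained by the code in section IV is as follows. Here x is an N×784 batch, W is 784×10, b has 10 entries (broadcast over the batch), y' stands for the one-hot label placeholder y_, and η = 0.5 is the GradientDescentOptimizer learning rate. The gradient lines are the standard result for softmax combined with cross-entropy; TensorFlow derives them automatically inside minimize().

```latex
\begin{aligned}
y &= \operatorname{softmax}(xW + b), \qquad
\operatorname{softmax}(z)_j = \frac{e^{z_j}}{\sum_{k=1}^{10} e^{z_k}} \\
H(y', y) &= -\frac{1}{N}\sum_{n=1}^{N}\sum_{j=1}^{10} y'_{n,j}\,\log y_{n,j} \\
\nabla_W H &= \frac{1}{N}\,x^{\top}(y - y'), \qquad
\nabla_b H = \frac{1}{N}\sum_{n=1}^{N}(y_n - y'_n) \\
W &\leftarrow W - \eta\,\nabla_W H, \qquad
b \leftarrow b - \eta\,\nabla_b H, \qquad \eta = 0.5
\end{aligned}
```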

2. Saving the model and viewing the graph in TensorBoard
Write the graph to the log directory:

graph = tf.get_default_graph()
summaryWriter = tf.summary.FileWriter('log/', graph)
summaryWriter.close()  # flush the event file to disk

View it with TensorBoard:

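tensorboard --logdir=log

If the tensorboard command is not on your PATH, its entry script can be run directly, e.g. python ~/.local/lib/python3.5/site-packages/tensorboard/main.py --logdir=log. Then open http://localhost:6006 in a browser; the graph is under the Graphs tab.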

3. Loading the model and running prediction on a test input
1) Using a .pb file
Loading:

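import os
import tensorflow as tf

# FLAGS.model_dir is the --model_dir command-line flag of classify_image.py
with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')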

Prediction:

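import numpy as np

# Raw JPEG bytes of the image passed via --image_file; the graph decodes them itself
image_data = tf.gfile.GFile(FLAGS.image_file, 'rb').read()
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
    # Drop the batch dimension: shape (1, N) -> (N,)
    predictions = np.squeeze(predictions)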
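After np.squeeze, predictions is a 1-D vector of class probabilities. classify_image.py then keeps the highest-scoring entries (five by default, set by its --num_top_predictions flag) and maps each node id to a readable label with its NodeLookup class. The selection step can be sketched in plain Python as follows; top_k and the sample probabilities are hypothetical, and the label lookup is left out.

```python
def top_k(probs, k=5):
    """Hypothetical helper: indices of the k largest probabilities, highest first."""
    # Same idea as NumPy's predictions.argsort()[-k:][::-1]
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


# Hypothetical squeezed softmax output standing in for `predictions`
predictions = [0.01, 0.70, 0.05, 0.20, 0.04]
for node_id in top_k(predictions, k=3):
    print('node %d (score = %.5f)' % (node_id, predictions[node_id]))
```

Sorting indices by probability avoids NumPy here. With real Inception output the vector is much longer, but the selection logic is the same.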

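VI. Learning from classify_image
1. Running the example
The TensorFlow models repository ships with the classic Google Inception network, so classify_image.py under tensorflow\models\tutorials\image\imagenet can be run directly: it downloads the pre-trained model and classifies the given image. Change into the imagenet directory and run classify_image.py:
python classify_image.py --model_dir ~/Image --image_file ~/Image/a.jpg
(--model_dir is the directory the model is downloaded to; --image_file is the image to classify.)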
The script prints the most likely class labels for the image, each with its score.
2. Saving the model with tf.train.server
