【发布时间】:2017-02-19 13:15:38
【问题描述】:
我已经下载了一个名为 inception_resnet_v2_2016_08_30.ckpt 的 tensorflow 检查点模型。
我是否需要创建一个在创建此检查点时使用的图表(包含所有变量)?
如何使用这个模型?
【问题讨论】:
标签: python tensorflow
我已经下载了一个名为 inception_resnet_v2_2016_08_30.ckpt 的 tensorflow 检查点模型。
我是否需要创建一个在创建此检查点时使用的图表(包含所有变量)?
如何使用这个模型?
【问题讨论】:
标签: python tensorflow
我是否需要创建一个在创建此检查点时使用的图表(包含所有变量)?
不,你没有。
至于如何使用检查点文件(cpkt文件)
1.这篇文章 (TensorFlow-Slim image classification library) 告诉你如何从头开始训练你的模型
2.以下是google blog的示例代码
import numpy as np
import os
import tensorflow as tf
import urllib2
from datasets import imagenet
from nets import inception
from preprocessing import inception_preprocessing
slim = tf.contrib.slim
batch_size = 3
image_size = inception.inception_v3.default_image_size
checkpoints_dir = '/root/code/model'
checkpoints_filename = 'inception_resnet_v2_2016_08_30.ckpt'
model_name = 'InceptionResnetV2'
sess = tf.InteractiveSession()
graph = tf.Graph()
graph.as_default()
def classify_from_url(url):
image_string = urllib2.urlopen(url).read()
image = tf.image.decode_jpeg(image_string, channels=3)
processed_image = inception_preprocessing.preprocess_image(image, image_size, image_size, is_training=False)
processed_images = tf.expand_dims(processed_image, 0)
# Create the model, use the default arg scope to configure the batch norm parameters.
with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
logits, _ = inception.inception_resnet_v2(processed_images, num_classes=1001, is_training=False)
probabilities = tf.nn.softmax(logits)
init_fn = slim.assign_from_checkpoint_fn(
os.path.join(checkpoints_dir, checkpoints_filename),
slim.get_model_variables(model_name))
init_fn(sess)
np_image, probabilities = sess.run([image, probabilities])
probabilities = probabilities[0, 0:]
sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])]
plt.figure()
plt.imshow(np_image.astype(np.uint8))
plt.axis('off')
plt.show()
names = imagenet.create_readable_names_for_imagenet_labels()
for i in range(5):
index = sorted_inds[i]
print('Probability %0.2f%% => [%s]' % (probabilities[index], names[index]))
【讨论】:
首先,您已经了解了内存中的网络架构。网络架构可以从here获取
一旦你有了这个程序,使用以下方法来使用模型:
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
height = 299
width = 299
channels = 3
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_resnet_v2_arg_scope()):
logits, end_points = inception_resnet_v2(X, num_classes=1001,is_training=False)
这样你就拥有了内存中的所有网络,现在你可以使用 tf.train.saver 使用检查点文件(ckpt)初始化网络:
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "/home/pramod/Downloads/inception_resnet_v2_2016_08_30.ckpt")
如果你想做瓶特征提取,它很简单,比如你想从最后一层获取特征,那么你只需声明predictions = end_points["Logits"] 如果你想为其他中间层获取它,你可以得到那些来自上述程序 inception_resnet_v2.py 的名称
之后您可以拨打:output = sess.run(predictions, feed_dict={X:batch_images})
【讨论】:
ValueError: No op named SSTableReaderV2 in defined operations. 你能告诉我我们是如何工作的吗围绕这个?
另一种加载预训练 Imagenet 模型的方法是
ResNet50
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
model = ResNet50()
model.summary()
InceptionV3
iport tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3
model = InceptionV3()
model.summary()
您可以查看与此here相关的详细说明
【讨论】: