【发布时间】:2017-08-04 02:44:11
【问题描述】:
我知道这是一个常见错误,但我无法理解这个问题。这是我的代码:
def convert_image(url):
checkpoint_file = './vgg_16.ckpt'
input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)
#Load the model
sess = tf.Session()
arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = vgg_16(scaled_input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
response = requests.get(url)
img = Image.open(BytesIO(response.content))
im = np.array(img, dtype='float32')
im = im.reshape(-1,224,224,3)
features = sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im})
sess.close()
return np.squeeze(features)
如您所见,我使用 VGG_16 预训练模型来提取 fc7 特征。大约 50% 的代码只是简单地从 URL 中获取图像并将其转换为 224x224x3;另外 50% 的 tensorflow 工作是为了实际获得特征表示。
问题是,我第一次运行此代码时它运行良好。但是,第二次,我收到上述错误。当然,“im”是一个 float32,即使我遇到了这个错误。所以我认为这个问题与我第二次运行这个函数时出现的问题有关。如果我不得不猜测,它与“保护程序”的工作方式有关,但我无法弄清楚究竟是什么。
有什么想法吗?
【问题讨论】:
标签: python tensorflow