【问题标题】:retraining last layer of inception-v3 significantly slowers the classification重新训练最后一层 inception-v3 会显着减慢分类速度
【发布时间】:2018-07-31 18:48:35
【问题描述】:

在尝试使用 TF 和 PY3.5 对 inception-v3 进行迁移学习时,我测试了两种方法:

1- 重新训练最后一层,如下所示:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/image_retraining

2- 在 inception-V3 瓶颈之上应用线性 SVM,如下所示:https://www.kernix.com/blog/image-classification-with-a-pre-trained-deep-neural-network_p11

不出所料,它们在分类阶段应该有类似的运行时间,因为关键部分 - 瓶颈提取 - 是相同的。但在实践中,经过再训练的网络在运行分类时会慢 8 倍左右。

我的问题是是否有人对此有想法。

一些代码sn-ps:

SVM 位于顶部(速度越快):

def getTensors():
    graph_def = tf.GraphDef()
    f = open('classify_image_graph_def.pb', 'rb')
    graph_def.ParseFromString(f.read())
    tensorBottleneck, tensorsResizedImage = tf.import_graph_def(graph_def, name='', return_elements=['pool_3/_reshape:0', 'Mul:0'])
    return tensorBottleneck, tensorsResizedImage 

def calc_bottlenecks(imgFile, tensorBottleneck, tensorsResizedImage):
    """ - read, decode and resize to get <resizedImage> - """
    bottleneckValues = sess.run(tensorBottleneck, {tensorsResizedImage : resizedImage})
    return np.squeeze(bottleneckValues)

这在我的 (Windows) 笔记本电脑上大约需要 0.5 秒,而 SVM 部分不需要时间。

重新训练最后一层 - (由于代码较长,这更难总结)

def loadGraph(pbFile):
    with tf.gfile.FastGFile(pbFile, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        softmaxTensor = sess.graph.get_tensor_by_name('final_result:0')

def labelImage(imageFile, softmaxTensor):
    with tf.Session() as sess:
        input_layer_name = 'DecodeJpeg/contents:0'
        predictions, = sess.run(softmax_tensor, {input_layer_name: image_data})

'pbFile'是retrainer保存的文件,除了分类层之外,它应该具有相同的拓扑和权重,如'classify_image_graph_def.pb'。运行大约需要 4 秒(在我的同一台笔记本电脑上,没有加载)。

关于性能差距的任何想法? 谢谢!

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:

    解决了。问题在于为每个图像创建一个新的 tf.Session() 。在读取图形时存储会话并使用它使运行时间回到预期。

    def loadGraph(pbFile):
        ...
        with tf.Session() as sess:
            softmaxTensor = sess.graph.get_tensor_by_name('final_result:0')
            sessToStore = sess
        return softmaxTensor, sessToStore  
    
    def labelImage(imageFile, softmaxTensor, sessToStore):
        input_layer_name = 'DecodeJpeg/contents:0'
        predictions, = sessToStore.run(softmax_tensor, {input_layer_name: image_data})
    

    【讨论】:

      猜你喜欢
      • 2017-05-15
      • 2022-01-27
      • 1970-01-01
      • 2019-06-11
      • 2020-03-24
      • 2017-03-28
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多