1.介绍

目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度

如下图检测人和自行车这两个目标,检测结果包括目标的位置、目标的类别和置信度

mxnet深度学习实战学习笔记-9-目标检测

因为目标检测算法需要输出目标的类别和具体坐标,因此在数据标签上不仅要有目标的类别,还要有目标的坐标信息

 

可见目标检测比图像分类算法更复杂。图像分类算法只租要判断图像中是否存在指定目标,不需要给出目标的具体位置;而目标检测算法不仅需要判断图像中是否存在指定类别的目标,还要给出目标的具体位置

因此目标检测算法实际上是多任务算法,一个任务是目标的分类,一个任务是目标位置的确定;二图像分类算法是单任务算法,只有一个分类任务

 

2.数据集

目前常用的目标检测公开数据集是PASCAL VOC(http://host.robots.ox.ac.uk/pascal/VOC) 和COCO(http://cocodataset.org/#home)数据集,PASCAL VOC常用的是PASCAL VOC2007和PASCAL VOC2012两个数据集,COCO常用的是COCO2014和COCO2017两个数据集

在评价指标中,目标检测算法和常见的图像分类算法不同,目标检测算法常用mAP(mean average precision)作为评价模型效果的指标。mAP值和设定的IoU(intersection-over-union)阈值相关,不同的IoU阈值会得到不同的mAP。目前在PASCAL VOC数据集上常用IoU=0.5的阈值;在COCO数据集中IoU阈值选择较多,常在IoU=0.50:0.05:0.95这10个IoU阈值上分别计算AP,然后求均值作为最终的mAP结果

另外在COCO数据集中还有针对目标尺寸而定义的mAP计算方式,可以参考COCO官方网站(http://cocodataset.org/#detection-eval)中对评价指标的介绍

 

目标检测算法在实际中的应用非常广泛,比如基于通用的目标检测算法,做车辆、行人、建筑物、生活物品等检测。在该基础上,针对一些特定任务或者场景,往往衍生出特定的目标检测算法,如人脸检测和文本检测

人脸检测:即目前的刷脸,如刷脸支付和刷脸解锁。包括人脸检测和人脸识别这两个主要步骤。人脸检测就是先从输入图像中检测到人脸所在区域,然后将检测到的人脸作为识别算法的输入得到分类结果

文本检测:文字检测作为光学字符识别(optical character recognition,OCR)的重要步骤,主要目的在于从输入图像中检测出文字所在的区域,然后作为文字识别器的输入进行识别

 

目标检测算法可分为两种类型:one-stage和two-stage,两者的区别在于前者是直接基于网络提取到的特征和预定义的框(anchor)进行目标预测;后者是先通过网络提取到的特征和预定义的框学习得到候选框(region of interest,RoI),然后基于候选框的特征进行目标检测

  • one-stage:代表是SSD(sigle shot detection)和YOLO(you only look once)等
  • two-stage:代表是Faster-RCNN 等

两者的差异主要有两方面:

一方面是one-stage算法对目标框的预测只进行一次,而two-stage算法对目标框的预测有两次,类似从粗到细的过程

另一方面one-stage算法的预测是基于整个特征图进行的,而two-stage算法的预测是基于RoI特征进行的。这个RoI特征就是初步预测得到框(RoI)在整个特征图上的特征,也就是从整个特征图上裁剪出RoI区域得到RoI特征

 

3.目标检测基础知识

 目标检测算法在网络结构方面和图像分类算法稍有不同,网络的主干部分基本上采用图像分类算法的特征提取网络,但是在网络的输出部分一般有两条支路,一条支路用来做目标分类,这部分和图像分类算法没有什么太大差异,也是通过交叉熵损失函数来计算损失;另一条支路用来做目标位置的回归,这部分通过Smooth L1损失函数计算损失。因此整个网络在训练时的损失函数是由分类的损失函数和回归的损失函数共同组成,网络参数的更新都是基于总的损失进行计算的,因此目标检测算法是多任务算法

1)one-stage

SSD算法首先基于特征提取网络提取特征,然后基于多个特征层设置不同大小和宽高比的anchor,最后基于多个特征层预测目标类别和位置,本章将使用的就是这个算法。

SSD算法在效果和速度上取得了非常好的平衡,但是在检测小尺寸目标上效果稍差,因此后续的优化算法,如DSSD、RefineDet等,主要就是针对小尺寸目标检测进行优化

YOLO算法的YOLO v1版本中还未引入anchor的思想,整体也是基于整个特征图直接进行预测。YOLO v2版本中算法做了许多优化,引入了anchor,有效提升了检测效果;通过对数据的目标尺寸做聚类分析得到大多数目标的尺寸信息从而初始化anchor。YOLO v3版本主要针对小尺寸目标的检测做了优化,同时目标的分类支路采用了多标签分类

2)two-stage

由RCNN 算法发展到Fast RCNN,主要引入了RoI Pooling操作提取RoI特征;再进一步发展到Faster RCNN,主要引入RPN网络生成RoI,从整个优化过程来看,不仅是速度提升明显,而且效果非常棒,目前应用广泛,是目前大多数two-stage类型目标检测算法的优化基础

Faster RCNN系列算法的优化算法非常多,比如R-FCN、FPN等。R-FCN主要是通过引入区域敏感(position  sensitive)的RoI Pooling减少了Faster RCNN算法中的重复计算,因此提速十分明显。FPN算法主要通过引入特征融合操作并基于多个融合后的特征成进行预测,有效提高了模型对小尺寸目标的检测效果

 

虽然算法分为上面的两种类别,但是在整体流程上,主要可以分为三大部分:

  • 主网络部分:主要用来提取特征,常称为backbone,一般采用图像分类算法的网络即可,比如常用的VGG和ResNet网络,目前也有在研究专门针对于目标检测任务的特征提取网络,比如DetNet
  • 预测部分:包含两个支路——目标类别的分类支路和目标位置的回归支路。预测部分的输入特征经历了从单层特征到多层特征,从多层特征到多层融合特征的过程,算法效果也得到了稳定的提升。其中Faster RCNN算法是基于单层特征进行预测的例子,SSD算法是基于多层特征进行预测的例子,FRN算法是基于多层融合特征进行预测的例子
  • NMS操作:(non maximum suppression,非极大值抑制)是目前目标检测算法常用的后处理操作,目的是去掉重复的预测框

 

3)准备数据集

VOL2007数据集包括9963张图像,其中训练验证集(trainval)有5011张图像(2G),测试集(test)有4952张

VOL2012数据集包含17125张图像,其中训练验证集(trainval)有11540张图像(450M),测试集(test)有5585张

首先使用命令下载数据:

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar

下载完后会在当前目录下看见名为VOCtrainval_11-May-2012.tar、VOCtrainval_06-Nov-2007.tar和VOCtest_06-Nov-2007.tar这三个压缩包,然后运行下面的命令进行解压缩:

tar -xvf VOCtrainval_06-Nov-2007.tar
tar -xvf VOCtest_06-Nov-2007.tar 
tar -xvf VOCtrainval_11-May-2012.tar

然后在当前目录下就会出现一个名为VOCdevkit的文件下,里面有两个文件为VOL2007和VOL2012,这里以VOL2007为例,可见有5个文件夹:

user@home:/opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2007$ ls
Annotations  ImageSets  JPEGImages  SegmentationClass  SegmentationObject

介绍这5个文件夹:

  • Annotations:存放的是关于图像标注信息文件(后缀为.xml)
  • ImageSets:存放的是训练和测试的图像列表信息
  • JPEGImages:存放的是图像文件
  • SegmentationClass: 存放的是和图像分割相关的数据,这一章暂不讨论,下章再说
  • SegmentationObject:存放的是和图像分割相关的数据,这一章暂不讨论,下章再说

ImageSets文件中有下面的四个文件夹:

  • Action:存储人的动作
  • Layout:存储人的部位
  • Main:存储检测索引
  • Segmentation :存储分割

其中Main中,每个类都有对应的classname_train.txt、classname_val.txt和classname_trainval.txt三个索引文件,分别对应训练集,验证集和训练验证集(即训练集+验证集)。

另外还有一个train.txt(5717)、val.txt(5823)和trainval.txt(11540)为所有类别的一个索引。

 

 Annotations包含于图像数量相等的标签文件(后缀为.xml),ls命令查看,有000001.xml到009963.xml这9963个文件

查看其中的000001.xml文件:

user@home:/opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2007/Annotations$ cat 000001.xml
<annotation>
        <folder>VOC2007</folder> <!--数据集名称 -->
        <filename>000001.jpg</filename> <!--图像名称 -->
        <source>
                <database>The VOC2007 Database</database>
                <annotation>PASCAL VOC2007</annotation>
                <image>flickr</image>
                <flickrid>341012865</flickrid>
        </source>
        <owner>
                <flickrid>Fried Camels</flickrid>
                <name>Jinky the Fruit Bat</name>
        </owner>
        <size> <!--图像长宽信息 -->
                <width>353</width>
                <height>500</height>
                <depth>3</depth>
        </size>
        <segmented>0</segmented>
        <object> <!--两个目标的标注信息 -->
                <name>dog</name> <!--目标的类别名,类别名以字符结尾,该类别为dog -->
                <pose>Left</pose>
                <truncated>1</truncated>
                <difficult>0</difficult>
                <bndbox> <!--目标的坐标信息,以字符结尾,包含4个坐标标注信息,且标注框都是矩形框 -->
                        <xmin>48</xmin> <!--矩形框左上角点横坐标 -->
                        <ymin>240</ymin> <!--矩形框左上角点纵坐标 -->
                        <xmax>195</xmax> <!--矩形框右下角点横坐标 -->
                        <ymax>371</ymax> <!--矩形框右下角点纵坐标 -->
                </bndbox>
        </object>
        <object>
                <name>person</name> <!--目标的类别名,类别名以字符结尾,该类别为person -->
                <pose>Left</pose>
                <truncated>1</truncated>
                <difficult>0</difficult>
                <bndbox>
                        <xmin>8</xmin>
                        <ymin>12</ymin>
                        <xmax>352</xmax>
                        <ymax>498</ymax>
                </bndbox>
        </object>
</annotation>

 

初了查看标签文件之外,还可以通过可视化方式查看这些真实框的信息,下面代码根据VOC数据集的一张图像和标注信息得到带有真实框标注的图像

运行时出现一个问题:

_tkinter.TclError: no display name and no $DISPLAY environment variable

原因是我们不是在Windows下跑的,是在Linux下跑的,不同的系统有不同的用户图形接口,所以要更改它的默认配置,把模式更改成Agg。即在代码最上面添加一行代码:

import matplotlib
matplotlib.use('Agg')

代码是:

import mxnet as mx
import xml.etree.ElementTree as ET
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random

#解析指定的xml标签文件,得到所有objects的名字和边框信息
def parse_xml(xml_path):
    bbox = []
    tree = ET.parse(xml_path)
    root = tree.getroot()
    objects = root.findall('object') #得到一个xml文件中的所有目标,这个例子中有dog和person两个object
    for object in objects:
        name = object.find('name').text #object的名字,即dog或person
        bndbox = object.find('bndbox')#得到object的坐标信息
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)
        bbox_i = [name, xmin, ymin, xmax, ymax]
        bbox.append(bbox_i)
    return bbox
#根据从xml文件中得到的信息,标注object边框并生成一张图片实现可视化
def visualize_bbox(image, bbox, name):
    fig, ax = plt.subplots()
    plt.imshow(image)
    colors = dict()#指定标注某个对象的边框的颜色
    for bbox_i in bbox:
        cls_name = bbox_i[0] #得到object的name
        if cls_name not in colors:
            colors[cls_name] = (random.random(), random.random(), random.random()) #随机生成标注name为cls_name的object的边框颜色
        xmin = bbox_i[1]
        ymin = bbox_i[2]
        xmax = bbox_i[3]
        ymax = bbox_i[4]
#指明对应位置和大小的边框 rect
= patches.Rectangle(xy=(xmin, ymin), width=xmax-xmin, height=ymax-ymin, edgecolor=colors[cls_name],facecolor='None',linewidth=3.5) plt.text(xmin, ymin-2, '{:s}'.format(cls_name), bbox=dict(facecolor=colors[cls_name], alpha=0.5)) ax.add_patch(rect) plt.axis('off') plt.savefig('./{}_gt.png'.format(name)) #将该图片保存下来 plt.close() if __name__ == '__main__': name = '000001' img_path = '/opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2007/JPEGImages/{}.jpg'.format(name) xml_path = '/opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2007/Annotations/{}.xml'.format(name) bbox = parse_xml(xml_path=xml_path) image_string = open(img_path, 'rb').read() image = mx.image.imdecode(image_string, flag=1).asnumpy() visualize_bbox(image, bbox, name)

运行:

user@home:/opt/user/.../PASCAL_VOL_datasets$ python2 mxnet_9_1.py
/usr/local/lib/python2.7/dist-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

返回的图片000001_gt.png是:

mxnet深度学习实战学习笔记-9-目标检测

 

4)SSD算法简介

 SSD时目前应用非常广泛的目标检测算法,其网络结构如下图所示:

mxnet深度学习实战学习笔记-9-目标检测

该算法采用修改后的16层网络作为特征提取网络,修改内容主要是将2个全连接层(图中的FC6和FC7)替换成了卷积层(图中的Conv6和Conv7),另外将第5个池化层pool5改成不改变输入特征图的尺寸。然后在网络的后面(即Conv7后面)添加一系列的卷积层(Extra Feature Layer),即图中的Conv8_2、Conv9_2、Conv10_2和Conv11_2,这样就构成了SSD网络的主体结构。

这里要注意Conv8_2、Conv9_2、Conv10_2和Conv11_2并不是4个卷积层,而是4个小模块,就像是resent网络中的block一样。以Conv8_2为例,Conv8_2包含一个卷积核尺寸是1*1的卷积层和一个卷积核尺寸为3*3的卷积层,同时这2个卷积层后面都有relu类型的激活层。当然这4个模块还有一些差异,Conv8_2和Conv9_2的3*3卷积层的stride参数设置为2、pad参数设置为1,最终能够将输入特征图维度缩小为原来的一半;而Conv10_2和Conv11_2的3*3卷积层的stride参数设置为1、pad参数设置为0

在SSD算法中采用基于多个特征层进行预测的方式来预测目标框的位置,具体而言就是使用Conv4_3、Conv7、Conv8_2、Conv9_2、Conv10_2和Conv11_2这6个特征层的输出特征图来进行预测。假设输入图像大小是300*300,那么这6个特征层的输出特征图大小分别是38*38、19*19、10*10、5*5、3*3和1*1。每个特征层都会有目标类别的分类支路和目标位置的回归支路,这两个支路都是由特定 卷积核数量的卷积层构成的,假设在某个特征层的特征图上每个点设置了k个anchor,目标的类别数一共是N,那么分类支路的卷积核数量就是K*(N+1),其中1表示背景类别;回归支路的卷积核数量就是K*4,其中4表示坐标信息。最终将这6个预测层的分类结果和回归结果分别汇总到一起就构成整个网络的分类和回归结果

 

5)anchor

该名词最早出现在Faster RCNN系列论文中,表示一系列固定大小、宽高比的框,这些框均匀地分布在输入图像上, 而检测模型的目的就是基于这些anchor得到预测狂的偏置信息(offset),使得anchor加上偏置信息后得到的预测框能尽可能地接近真实目标框。在SSD论文中的default box名词,即默认框,和anchor的含义是类似的

 

MXNet框架提供了生成anchor的接口:mxnet.ndarray.contrib.MultiBoxPrior(),接下来通过具体数据演示anchor的含义

首先假设输入特征图大小是2*2,在SSD算法中会在特征图的每个位置生成指定大小和宽高比的anchor,大小的设定通过mxnet.ndarray.contrib.MultiBoxPrior()接口的sizes参数实现,而宽高比通过ratios参数实现,代码如下:

import mxnet as mx
import matplotlib.pyplot as plt
import matplotlib.patches as patches

input_h = 2
input_w = 2
input = mx.nd.random.uniform(shape=(1,3,input_h, input_w))
anchors = mx.ndarray.contrib.MultiBoxPrior(data=input, sizes=[0.3], ratios=[1])
print(anchors)

返回:

[[[0.09999999 0.09999999 0.4        0.4       ]
  [0.6        0.09999999 0.9        0.4       ]
  [0.09999999 0.6        0.4        0.9       ]
  [0.6        0.6        0.9        0.9       ]]]
<NDArray 1x4x4 @cpu(0)>

可见因为输入特征图大小是2*2,且设定的anchor大小和宽高比都只有1种,因此一共得到4个anchor,每个anchor都是1*4的向量,分别表示[xmin,ymin,xmax,ymax],也就是矩形框的左上角点坐标和右下角点坐标

 

接下来通过维度变换可以更清晰地看到anchor数量和输入特征维度的关系,最后一维的4表示每个anchor的4个坐标信息:

anchors = anchors.reshape((input_h, input_w, -1, 4))
print(anchors.shape)
anchors

返回:

(2, 2, 1, 4)

[[[[0.09999999 0.09999999 0.4        0.4       ]]

  [[0.6        0.09999999 0.9        0.4       ]]]


 [[[0.09999999 0.6        0.4        0.9       ]]

  [[0.6        0.6        0.9        0.9       ]]]]
<NDArray 2x2x1x4 @cpu(0)>

 

那么这4个anchor在输入图像上具体是什么样子?接下来将这些anchor显示在一张输入图像上,首先定义一个显示anchor的函数:

def plot_anchors(anchors, sizeNum, ratioNum): #sizeNum和ratioNum只是用于指明生成的图的不同anchor大小和高宽比
    img = mx.img.imread('./000001.jpg')
    height, width, _ = img.shape
    fig, ax = plt.subplots(1)
    ax.imshow(img.asnumpy())
    edgecolors = ['r', 'g', 'y', 'b']
    for h_i in range(anchors.shape[0]):
        for w_i in range(anchors.shape[1]):
            for index, anchor in enumerate(anchors[h_i, w_i, :, :].asnumpy()):
                xmin = anchor[0]*width
                ymin = anchor[1]*height
                xmax = anchor[2]*width
                ymax = anchor[3]*height
                rect = patches.Rectangle(xy=(xmin,ymin), width=xmax-xmin, 
                                         height=ymax-ymin,edgecolor=edgecolors[index],
                                        facecolor='None', linewidth=1.5)
                ax.add_patch(rect)
    plt.savefig('./mapSize_{}*{}_sizeNum_{}_ratioNum_{}.png'.format(anchors.shape[0], 
                                                                    anchors.shape[1], sizeNum, ratioNum))

调用函数:

plot_anchors(anchors, 1, 1)

返回:

mxnet深度学习实战学习笔记-9-目标检测

通过修改或增加anchor的宽高比及大小可以得到不同数量的anchor,比如增加宽高比为2和0.5的anchor

input_h = 2
input_w = 2
input = mx.nd.random.uniform(shape=(1,3,input_h, input_w))
anchors = mx.nd.contrib.MultiBoxPrior(data=input, sizes=[0.3],ratios=[1,2,0.5])
anchors = anchors.reshape((input_h, input_w, -1, 4))
print(anchors.shape)
plot_anchors(anchors, 1, 3)

返回:

(2, 2, 3, 4)

图为:

mxnet深度学习实战学习笔记-9-目标检测

输出结果说明在2*2的特征图上的每个点都生成了3个anchor

 

接下来再增加大小为0.4的anchor:

input_h = 2
input_w = 2
input = mx.nd.random.uniform(shape=(1,3,input_h, input_w))
anchors = mx.nd.contrib.MultiBoxPrior(data=input, sizes=[0.3,0.4],ratios=[1,2,0.5])
anchors = anchors.reshape((input_h, input_w, -1, 4))
print(anchors.shape)
plot_anchors(anchors, 2, 3)

返回:

(2, 2, 4, 4)

图为:

mxnet深度学习实战学习笔记-9-目标检测

说明在2*2大小的特征图上的每个点都生成了4个anchor,为什么得到的是4个,而不是2*3=6个呢?

因为在SSD论文中设定anchor时并不是组合所有设定的尺寸和宽高对比度值,而是分成2部分,一部分是针对每种宽高对比度都与其中一个尺寸size进行组合;另一部分是针对宽高对比度为1时,还会额外增加一个新尺寸与该宽高对比度进行组合

举例说明sizes=[s1, s2, ..., sm], ratios=[r1, r2,...,rn],计算得到的anchor数量为m+n-1,所以当m=2,n=3时,得到的anchor数就是4

首先第一部分就是sizes[0]会跟所有ratios组合,这就有n个anchor了;第二部分就是sizes[1:]会和ratios[0]组合,这样就有m-1个anchor了。对应这个例子就是[(0.3,1), (0.3,2), (0.3,0.5), (0.4,1)]。SSD论文中ratios参数的第一个值要设置为1

 

上面的例子使用的是2*2的特征图,下面改成5*5的特征图:

input_h = 5
input_w = 5
input = mx.nd.random.uniform(shape=(1,3,input_h, input_w))
anchors = mx.nd.contrib.MultiBoxPrior(data=input, sizes=[0.1,0.15],ratios=[1,2,0.5])
anchors = anchors.reshape((input_h, input_w, -1, 4))
print(anchors.shape)
plot_anchors(anchors, 2, 3)

返回:

(5, 5, 4, 4)

图为:

mxnet深度学习实战学习笔记-9-目标检测

需要说明的是上述代码中设定的anchor大小和特征图大小都是比较特殊的值,因此特征图上不同点之间的anchor都没有重叠,这是为了方便显示anchor而设置的。在实际的SSD算法中,特征图上不同点之间的anchor重叠特别多,因此基本上能够覆盖所有物体

SSD算法基于多个特征层进行目标的预测,这些特征层的特征图大小不一,因此设置的anchor大小也不一样,一般而言在网络的浅层部分特征图尺寸较大(如38*38、19*19),此时设置的anchor尺寸较小(比如0.1、0.2),主要用来检测小尺寸目标;在网络的深层部分特征图尺寸较小(比如3*3、1*1),此时设置的anchor尺寸较大(比如0.8、0.9),主要用来检测大尺寸目标

 

6)IoU

在目标检测算法中,我们经常需要评价2个矩形框之间的相似性,直观来看可以通过比较2个框的距离、重叠面积等计算得到相似性,而IoU指标恰好可以实现这样的度量。简而言之,IoU(intersection over union,交并比)是目标检测算法中用来评价2个矩形框之间相似度的指标

IoU = 两个矩形框相交的面积 / 两个矩形框相并的面积,如下图所示:

mxnet深度学习实战学习笔记-9-目标检测

其作用是在我们设定好了anchor后,需要判断每个anchor的标签,而判断的依据就是anchor和真实目标框的IoU。假设某个anchor和某个真实目标框的IoU大于设定的阈值,那就说明该anchor基本覆盖了这个目标,因此就可以认为这个anchor的类别就是这个目标的类别

另外在NMS算法中也需要用IoU指标界定2个矩形框的重合度,当2个矩形框的IoU值超过设定的阈值时,就表示二者是重复框

 

 7)模型训练目标

目标检测算法中的位置回归目标一直是该类算法中较难理解的部分, 一开始都会认为回归部分的训练目标就是真实框的坐标,其实不是。网络的回归支路的训练目标是offset,这个offset是基于真实框坐标和anchor坐标计算得到的偏置,而回归支路的输出值也是offset,这个offset是预测框坐标和anchor坐标之间的偏置。因此回归的目的就是让这个偏置不断地接近真实框坐标和anchor坐标之间的偏置

使用的接口是:mxnet.ndarray.contrib.MultiBoxTarget(),生成回归和分类的目标

import mxnet as mx
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def plot_anchors(anchors, img, text, linestyle='-'): #定义可视化anchor或真实框的位置的函数
    height, width, _ = img.shape
    colors = ['r','y','b','c','m']
    for num_i in range(anchors.shape[0]):
        for index, anchor in enumerate(anchors[num_i,:,:].asnumpy()):
            xmin = anchor[0]*width
            ymin = anchor[1]*height
            xmax = anchor[2]*width
            ymax = anchor[3]*height
            rect = patches.Rectangle(xy=(xmin,ymin), width=xmax-xmin,
                                     height=ymax-ymin, edgecolor=colors[index],
                                     facecolor='None', linestyle=linestyle,
                                     linewidth=1.5)
            ax.text(xmin, ymin, text[index],
                    bbox=dict(facecolor=colors[index], alpha=0.5))
            ax.add_patch(rect)
#读取输入图像
img = mx.img.imread("./000001.jpg")
fig,ax = plt.subplots(1)
ax.imshow(img.asnumpy())
#在上面的输入图像上标明真实框的位置 ground_truth
= mx.nd.array([[[0, 0.136,0.48,0.552,0.742], #对应类别0 dog的真实框坐标值 [1, 0.023,0.024,0.997,0.996]]])#对应类别1 person的真实框坐标值 plot_anchors(anchors=ground_truth[:, :, 1:], img=img, text=['dog','person'])
#在上面的输入图像上标明anchor的位置
#坐标值表示[xmin, ymin, xmax, ymax] anchor
= mx.nd.array([[[0.1, 0.3, 0.4, 0.6], [0.15, 0.1, 0.85, 0.8], [0.1, 0.2, 0.6, 0.4], [0.25, 0.5, 0.55, 0.7], [0.05, 0.08, 0.95, 0.9]]]) plot_anchors(anchors=anchor, img=img, text=['1','2','3','4','5'], linestyle=':') #然后保存图片,图片如下图所示 plt.savefig("./anchor_gt.png")

 图为:

mxnet深度学习实战学习笔记-9-目标检测

接下来初始化一个分类预测值,维度是1*2*5,其中1表示图像数量,2表示目标类别,这里假设只有人和狗两个类别,5表示anchor数量,然后就可以通过mxnet.ndarray.contrib.MultiBoxTarget()接口获取模型训练的目标值。

该接口主要包含一下几个输入:

  • anchor :该参数在计算回归目标offset时需要用到
  • label:该参数在计算回归目标offset和分类目标时都用到
  • cls_pred :该参数内容其实在这里并未用到,因此只要维度符合要求即可
  • overlap_threshold: 该参数表示当预测框和真实框的IoU大于这个值时,该预测框的分类和回归目标就和该真实框对应
  • ignore_label :该参数表示计算回归目标时忽略的真实框类别标签,因为训练过程中一个批次有多张图像,每张图像的真实框数量都不一定相同,因此会采用全 -1 值来填充标签使得每张图像的真实标签维度相同,因此这里相当于忽略掉这些填充值
  • negative_mining_ratio :该参数表示在对负样本做过滤时设定的正负样本比例是1:3
  • variances :该参数表示计算回归目标时中心点坐标(x和y)的权重是0.1,宽和高的offset权重是0.2
cls_pred = mx.nd.array([[[0.4, 0.3, 0.2, 0.1, 0.1],
                        [0.6, 0.7, 0.8, 0.9, 0.9]]])
tmp = mx.nd.contrib.MultiBoxTarget(anchor=anchor, label=ground_truth,
                                  cls_pred=cls_pred, overlap_threshold=0.5,
                                  ignore_label=-1, negative_mining_ratio=3,
                                  variances=[0.1,0.1,0.2,0.2])
print("location target: {}".format(tmp[0]))
print("location target mask: {}".format(tmp[1]))
print("classification target: {}".format(tmp[2]))

 这里三个变量的含义是:

  • tmp[0] : 输出的是回归支路的训练目标,也就是我们希望模型的回归支路输出值和这个目标的smooth L1损失值要越小越好。可以看见tmp[0]的维度是1*20,其中1表示图像数量,20是4*5的意思,也就是5个anchor,每个anchor有4个坐标信息。另外tmp[0]中有部分是0,表示这些anchor都是负样本,也就是背景,可以从输出结果看出1号和3号anchor是背景
  • tmp[1]:输出的是回归支路的mask,该mask中对应正样本anchor的坐标用1填充,对应负样本anchor的坐标用0填充。该变量是在计算回归损失时用到,计算回归损失时负样本anchor是不参与计算的
  • tmp[2]:输出的是每个anchor的分类目标,在接口中默认类别0表示背景类,其他类别依次加1,因此dog类别就用类别1表示,person类别就用类别2表示

返回:

location target: 
[[ 0.          0.          0.          0.          0.14285699  0.8571425
   1.6516545   1.6413777   0.          0.          0.          0.
  -1.8666674   0.5499989   1.6345134   1.3501359   0.11111101  0.24390258
   0.3950827   0.8502576 ]]
<NDArray 1x20 @cpu(0)>
location target mask: 
[[0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]]
<NDArray 1x20 @cpu(0)>
classification target: 
[[0. 2. 0. 1. 2.]]
<NDArray 1x5 @cpu(0)>

所以从上面的结果我们可以知道,anchor 1和3是背景,anchor 2和5是person,4是dog

 

那么anchor的类别怎么定义呢?

在SSD算法中,首先每个真实框和N个anchor会计算到N个IoU,这N个IoU中的最大值对应的anchor就是正样本,而且类别就是这个真实框的类别。比如上面的图中与person这个真实框计算得到的IoU中最大的是5号anchor,所以5号anchor的分类目标就是person,也就是类别2,所以上面tmp[2][4]的值为2。同理,dog这个真实框的IoU最大的是4号anchor,因此4号anchor的分类目标就是dog,也就是类别1,所以上面的tmp[2][3]等于1。

除了IoU最大的anchor是正样本外,和真实框的IoU大于设定的IoU阈值的anchor也是正样本。这个阈值就是mxnet.ndarray.contrib.MultiBoxTarget()接口中的overlap_threshold参数设置的。显然可以看出2号anchor和person这个真实框的IoU大于设定的0.5阈值,因此2号anchor的预测类别为person,即类别2,tmp[2][1]等于2

 

关于回归目标的计算,在SSD论文中通过公式2已经介绍非常详细了。假设第i个anchor(用di表示),第j个真实框(用gi表示),那么回归目标就是如下这4个值:

mxnet深度学习实战学习笔记-9-目标检测

按照上面的公式得到的就是输出的tmp[0]的值

 

8)NMS

在目标检测算法中,我们希望每一个目标都有一个预测框准确地圈出目标的位置并给出预测类别。但是检测模型的输出预测框之间可能存在重叠,也就是说针对一个目标可能会有几个甚至几十个预测对的预测框,这显然不是我们想要的,因此就有了NMS操作

NMS(non maximum suprression,非极大值抑制)是目前目标检测算法常用的后处理操作,目的是去掉重复的预测框

NMS算法的过程大致如下:

假设网络输出的预测框中预测类别为person的框有K个,每个预测框都有1个预测类别、1个类别置信度和4个坐标相关的值。K个预测框中有N个预测框的类别置信度大于0。首先在N个框中找到类别置信度最大的那个框,然后计算剩下的N-1个框和选出来的这个框的IoU值,IoU值大于预先设定的阈值的框即为重复预测框(假设有M个预测框和选出来的框重复),剔除这M个预测框(这里将这M个预测框的类别置信度设置为0,表示剔除),保留IoU小于阈值的预测框。接下来再从N-1-M个预测框中找到类别置信度最大的那个框,然后计算剩下的N-2-M个框和选出来的这个框的IoU值,同样将IoU值大于预先设定的阈值的框剔除,保留IoU值小于阈值的框,然后再进行下一轮过滤,一直进行到所有框都过滤结束。最终保留的预测框就是输出结果,这样任意两个框的IoU都小于设定的IoU阈值,就达到去掉重复预测框的目的

mxnet深度学习实战学习笔记-9-目标检测

mxnet深度学习实战学习笔记-9-目标检测

mxnet深度学习实战学习笔记-9-目标检测

 

 

 9)评价指标mAP

 在目标检测算法中常用的评价指标是mAP(mean average precision),这是一个可以用来度量模型预测框类别和位置是否准确的指标。在目标检测领域常用的公开数据集PASCAL VOC中,有2种mAP计算方式,一种是针对PASCAL VOL 2007数据集的mAP计算方式,另一种是针对PASCAL VOC 2012数据集的mAP计算方式,二者差异较小,这里主要是用第一种

含义和计算过程如下:

假设某输入图像中有两个真实框:person和dog,模型输出的预测框中预测类别为person的框有5个,预测类别是dog的框有3个,此时的预测框已经经过NMS后处理。

首先以预测类别为person的5个框为例,先对这5个框按照预测的类别置信度进行从大到小排序,然后这5个值依次和person类别的真实框计算IoU值。假设IoU值大于预先设定的阈值(常设为0.5),那就说明这个预测框是对的,此时这个框就是TP(true positive);假设IoU值小于预先设定的阈值(常设为0.5),那就说明这个预测框是错的,此时这个框就是FP(false positive)。注意如果这5个预测框中有2个预测框和同一个person真实框的IoU大于阈值,那么只有类别置信度最大的那个预测框才算是预测对了,另一个算是FP

假设图像的真实框类别中不包含预测框类别,此时预测框类别是cat,但是图像的真实框只有person和dog,那么也算该预测框预测错了,为FP

FN(false negative)的计算可以通过图像中真实框的数量间接计算得到,因为图像中真实框的数量 = TP + FN

 

mxnet深度学习实战学习笔记-9-目标检测

癌症类别的精确度就是指模型判断为癌症且真实类别也为癌症的图像数量/模型判断为癌症的图像数量,计算公式如下图: 

mxnet深度学习实战学习笔记-9-目标检测

召回率是指模型判为癌症且真实类别也是癌症的图像数量/真实类别是癌症的图像数量,计算公式为:

mxnet深度学习实战学习笔记-9-目标检测

 

得到的person类精确度和召回率都是一个列表,列表的长度和预测类别为person的框相关,因此根据这2个列表就可以在一个坐标系中画出该类别的precision和recall曲线图

按照PASCAL VOL 2007的mAP计算方式,在召回率坐标轴均匀选取11个点(0, 0.1, ..., 0.9, 1),然后计算在召回率大于0的所有点中,精确度的最大值是多少;计算在召回率大于0.1的所有点中,精确度的最大值是多少;一直计算到在召回率大于1时,精确度的最大值是多少。这样我们最终得到11个精确度值,对这11个精确度求均值久得到AP了,因此AP中的A(average)就代表求精确度均值的过程

 

mAP和AP的关系是:

因为我们有两个类别,所以会得到person和dog两个类别对应的AP值,这样将这2个AP求均值就得到了mAP

所以如果有N个类别,就分别求这N个类别的AP值,然后求均值就得到mAP了

 

上面说到有两种mAP计算方式,两者的不同在于AP计算的不同,对于2012标准,是以召回率为依据计算AP

 

那么为什么可以使用mAP来评价目标检测的效果;

目标检测的效果取决于预测框的位置和类别是否准确,从mAP的计算过程中可以看出通过计算预测框和真实框的IoU来判断预测框是否准确预测到了位置信息,同时精确度和召回率指标的引用可以评价预测框的类别是否准确,因此mAP是目前目标检测领域非常常用的评价指标

 

4.通用目标检测

1)数据准备

当要在自定义数据集上训练检测模型时,只需要将自定义数据按照PASCAL VOC数据集的维护方式进行维护,就可以顺利进行

本节采用的训练数据包括VOC2007的trainval.txt和VOC2012的trainval.txt,一共16551张图像;验证集采用VOC2007的test.txt,一共4952张图像,这是常用的划分方式

另外为了使数据集更加通用,将VOC2007和VOC2012的trainval.txt文件(在/ImageSets/Main文件夹中)合并在一起,同时合并对应的图像文件夹JPEGImages和标签文件夹Annotations

 手动操作,创建一个新的文件夹PASCAL_VOC_converge将需要的文件在这里合并

首先合并JPEGImages

cp -r /opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2007/JPEGImages/. /opt/user/.../PASCAL_VOC_converge/JPEGImages
cp -r /opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2012/JPEGImages/. /opt/user/.../PASCAL_VOC_converge/JPEGImages

然后合并Annotations:

cp -r /opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2007/Annotations/. /opt/user/.../PASCAL_VOC_converge/Annotations
cp -r /opt/user/.../PASCAL_VOL_datasets/VOCdevkit/VOC2012/Annotations/. /opt/user/.../PASCAL_VOC_converge/Annotations

然后将两者的trainval.txt内容放在同一个trainval.txt中,然后要将trainval.txt和VOC2007的test.txt都放进新创建的lst文件夹中

最后新的文件夹PASCAL_VOC_converge中的文件为:

user@home:/opt/user/.../PASCAL_VOC_converge$ ls
Annotations  create_list.py  JPEGImages  lst
user@home:/opt/user/.../PASCAL_VOC_converge/lst$ ls
test.txt  trainval.txt

 

然后接下来开始基于数据生成.lst文件和RecordIO文件,生成.lst文件的脚本为create_list.py:

import os
import argparse
from PIL import Image
import xml.etree.ElementTree as ET
import random

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set', type=str, default='train')
    parser.add_argument('--save-path', type=str, default='')
    parser.add_argument('--dataset-path', type=str, default='')
    parser.add_argument('--shuffle', type=bool, default=False)
    args = parser.parse_args()
    return args

def main():
    label_dic = {"aeroplane": 0, "bicycle": 1, "bird": 2, "boat": 3, "bottle": 4, "bus": 5,
                 "car": 6, "cat": 7, "chair": 8, "cow": 9, "diningtable": 10, "dog": 11,
                 "horse": 12, "motorbike": 13, "person": 14, "pottedplant": 15, "sheep": 16,
                 "sofa": 17, "train": 18, "tvmonitor": 19}
    args = parse_args()
    if not os.path.exists(os.path.join(args.save_path, "{}.lst".format(args.set))):
        os.mknod(os.path.join(args.save_path, "{}.lst".format(args.set)))
    with open(os.path.join(args.save_path, "{}.txt".format(args.set)), "r") as input_file:
        lines = input_file.readlines()
        if args.shuffle:
            random.shuffle(lines)
        with open(os.path.join(args.save_path, "{}.lst".format(args.set)), "w") as output_file:
            index = 0
            for line in lines:
                line = line.strip()
                out_str = "\t".join([str(index), "2", "6"])
                img = Image.open(os.path.join(args.dataset_path, "JPEGImages", line+".jpg"))
                width, height = img.size
                xml_path = os.path.join(args.dataset_path, "Annotations", line+".xml")
                tree = ET.parse(xml_path)
                root = tree.getroot()
                objects = root.findall('object')
                for object in objects:
                    name = object.find('name').text
                    difficult = ("%.4f" % int(object.find('difficult').text))
                    label_idx = ("%.4f" % label_dic[name])
                    bndbox = object.find('bndbox')
                    xmin = ("%.4f" % (int(bndbox.find('xmin').text)/width))
                    ymin = ("%.4f" % (int(bndbox.find('ymin').text)/height))
                    xmax = ("%.4f" % (int(bndbox.find('xmax').text)/width))
                    ymax = ("%.4f" % (int(bndbox.find('ymax').text)/height))
                    object_str = "\t".join([label_idx, xmin, ymin, xmax, ymax, difficult])
                    out_str = "\t".join([out_str, object_str])
                out_str = "\t".join([out_str, "{}/JPEGImages/".format(args.dataset_path.split("/")[-1])+line+".jpg"+"\n"])
                output_file.writelines(out_str)
                index += 1

if __name__ == '__main__':
    main()

命令为:

user@home:/opt/user/.../PASCAL_VOC_converge$ python create_list.py --set test --save-path /opt/user/.../PASCAL_VOC_converge/lst --dataset-path /opt/user/.../PASCAL_VOC_converge

user@home:/opt/user/.../PASCAL_VOC_converge$ python create_list.py --set trainval --save-path /opt/user/.../PASCAL_VOC_converge/lst --dataset-path /opt/user/.../PASCAL_VOC_converge --shuffle True

然后查看可见/opt/user/.../PASCAL_VOC_converge/lst文件夹下生成了相应的trainval.lst和test.lst文件:

user@home:/opt/user/.../PASCAL_VOC_converge/lst$ ls
test.lst  test.txt  trainval.lst  trainval.txt

对上面的命令进行说明:

  • --set:用来指定生成的列表文件的名字,如test说明是用来生成test.txt文件指明的测试的数据集的.lst文件,生成的文件名为test.lst,trainval则生成trainval.txt文件指明的训练验证数据集的.lst文件,生成的文件名为trainval.lst
  • --save-path:用来指定生成的.lst文件的保存路径,trainval.txt和test.txt文件要保存在该路径下
  • --dataset-path:用来指定数据集的根目录

截取train.lst文件中的一个样本的标签介绍.lst文件的内容,如下:

0       2       6       6.0000  0.4300  0.4853  0.7540  0.6400  0.0000  PASCAL_VOC_converge/JPEGImages/2008_000105.jpg

该图为:

mxnet深度学习实战学习笔记-9-目标检测

列与列之间都是采用Tab键进行分割的。

  • 第一列是index,即图像的标号,默认从0开始,然后递增。
  • 第二列表示标识符位数,这里第二列的值都为2,因为标识符有2位,也就是第2列和第3列都是标识符,不是图像标签
  • 第三列表示每个目标的标签位数,这里第三列都是6,表示每个目标的标签都是6个数字
  • 第4列到第9列这6个数字就是第一个目标的标签,其中6表示该目标的类别,即'car'(PASCAL VOC数据集又20类,该列值为0-19);接下来的4个数字(0.4300 0.4853 0.7540 0.6400)表示目标的位置,即(xmin, ymin, xmax, ymax);第9列的值表示是否是difficult,如果为0则表示该目标能正常预测,1则表示该目标比较难检测

如果还有第二个目标的话,那么在第一个目标后面就会接着第二个目标的6列信息,依此类推

  • 最后一列是图像的路径

多个目标可见:

18      2       6       14.0000 0.5620  0.0027  1.0000  1.0000  0.0000  19.0000 0.0680  0.3013  0.5760  0.9653  0.0000  PASCAL_VOC_conve
rge/JPEGImages/2008_004301.jpg  

图2008_004301.jpg为:

mxnet深度学习实战学习笔记-9-目标检测

目标分别是类别14的'person'和类别19的'tvmonitor'

 

生成.lst文件后,就可以基于.lst文件和图像文件生成RecordIO文件了

使用脚本im2rec.py:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import os
import sys

curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import mxnet as mx
import random
import argparse
import cv2
import time
import traceback

try:
    import multiprocessing
except ImportError:
    multiprocessing = None

def list_image(root, recursive, exts):
    """Traverses the root of directory that contains images and
    generates image list iterator.
    Parameters
    ----------
    root: string
    recursive: bool
    exts: string
    Returns
    -------
    image iterator that contains all the image under the specified path
    """

    i = 0
    if recursive:
        cat = {}
        for path, dirs, files in os.walk(root, followlinks=True):
            dirs.sort()
            files.sort()
            for fname in files:
                fpath = os.path.join(path, fname)
                suffix = os.path.splitext(fname)[1].lower()
                if os.path.isfile(fpath) and (suffix in exts):
                    if path not in cat:
                        cat[path] = len(cat)
                    yield (i, os.path.relpath(fpath, root), cat[path])
                    i += 1
        for k, v in sorted(cat.items(), key=lambda x: x[1]):
            print(os.path.relpath(k, root), v)
    else:
        for fname in sorted(os.listdir(root)):
            fpath = os.path.join(root, fname)
            suffix = os.path.splitext(fname)[1].lower()
            if os.path.isfile(fpath) and (suffix in exts):
                yield (i, os.path.relpath(fpath, root), 0)
                i += 1

def write_list(path_out, image_list):
    """Hepler function to write image list into the file.
    The format is as below,
    integer_image_index \t float_label_index \t path_to_image
    Note that the blank between number and tab is only used for readability.
    Parameters
    ----------
    path_out: string
    image_list: list
    """
    with open(path_out, 'w') as fout:
        for i, item in enumerate(image_list):
            line = '%d\t' % item[0]
            for j in item[2:]:
                line += '%f\t' % j
            line += '%s\n' % item[1]
            fout.write(line)

def make_list(args):
    """Generates .lst file.
    Parameters
    ----------
    args: object that contains all the arguments
    """
    image_list = list_image(args.root, args.recursive, args.exts)
    image_list = list(image_list)
    if args.shuffle is True:
        random.seed(100)
        random.shuffle(image_list)
    N = len(image_list)
    chunk_size = (N + args.chunks - 1) // args.chunks
    for i in range(args.chunks):
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        if args.chunks > 1:
            str_chunk = '_%d' % i
        else:
            str_chunk = ''
        sep = int(chunk_size * args.train_ratio)
        sep_test = int(chunk_size * args.test_ratio)
        if args.train_ratio == 1.0:
            write_list(args.prefix + str_chunk + '.lst', chunk)
        else:
            if args.test_ratio:
                write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
            if args.train_ratio + args.test_ratio < 1.0:
                write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
            write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])

def read_list(path_in):
    """Reads the .lst file and generates corresponding iterator.
    Parameters
    ----------
    path_in: string
    Returns
    -------
    item iterator that contains information in .lst file
    """
    with open(path_in) as fin:
        while True:
            line = fin.readline()
            if not line:
                break
            line = [i.strip() for i in line.strip().split('\t')]
            line_len = len(line)
            # check the data format of .lst file
            if line_len < 3:
                print('lst should have at least has three parts, but only has %s parts for %s' % (line_len, line))
                continue
            try:
                item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]
            except Exception as e:
                print('Parsing lst met error for %s, detail: %s' % (line, e))
                continue
            yield item

def image_encode(args, i, item, q_out):
    """Reads, preprocesses, packs the image and put it back in output queue.
    Parameters
    ----------
    args: object
    i: int
    item: list
    q_out: queue
    """
    fullpath = os.path.join(args.root, item[1])

    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        try:
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except:
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return
    if args.center_crop:
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return

def read_worker(args, q_in, q_out):
    """Function that will be spawned to fetch the image
    from the input queue and put it back to output queue.
    Parameters
    ----------
    args: object
    q_in: queue
    q_out: queue
    """
    while True:
        deq = q_in.get()
        if deq is None:
            break
        i, item = deq
        image_encode(args, i, item, q_out)

def write_worker(q_out, fname, working_dir):
    """Function that will be spawned to fetch processed image
    from the output queue and write to the .rec file.
    Parameters
    ----------
    q_out: queue
    fname: string
    working_dir: string
    """
    pre_time = time.time()
    count = 0
    fname = os.path.basename(fname)
    fname_rec = os.path.splitext(fname)[0] + '.rec'
    fname_idx = os.path.splitext(fname)[0] + '.idx'
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                           os.path.join(working_dir, fname_rec), 'w')
    buf = {}
    more = True
    while more:
        deq = q_out.get()
        if deq is not None:
            i, s, item = deq
            buf[i] = (s, item)
        else:
            more = False
        while count in buf:
            s, item = buf[count]
            del buf[count]
            if s is not None:
                record.write_idx(item[0], s)

            if count % 1000 == 0:
                cur_time = time.time()
                print('time:', cur_time - pre_time, ' count:', count)
                pre_time = cur_time
            count += 1

def parse_args():
    """Defines all arguments.
    Returns
    -------
    args object that contains all the params
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or \
        make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', action='store_true',
                        help='If this is set im2rec will create image list(s) by traversing root folder\
        and output to <prefix>.lst.\
        Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', action='store_true',
                        help='If true recursively walk through subdirs and assign an unique label\
        to images in each folder. Otherwise only include images in the root folder\
        and give them label 0.')
    cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                        help='If this is passed, \
        im2rec will not randomize the image order in <prefix>.lst')
    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--pass-through', action='store_true',
                        help='whether to skip transformation and save image as is')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will\
        be packed by default.')
    rgroup.add_argument('--center-crop', action='store_true',
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=95,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different\
        from the input list if >1. the input list will be modified to match the\
        resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image.\
        1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
        0: Loads image in grayscale mode.\
        -1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--pack-label', action='store_true',
        help='Whether to also pack multi dimensional label in the record file')
    args = parser.parse_args()
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args

if __name__ == '__main__':
    args = parse_args()
    # if the '--list' is used, it generates .lst file
    if args.list:
        make_list(args)
    # otherwise read .lst file to generates .rec file
    else:
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                    if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if fname.startswith(args.prefix) and fname.endswith('.lst'):
                print('Creating .rec file from', fname, 'in', working_dir)
                count += 1
                image_list = read_list(fname)
                # -- write_record -- #
                if args.num_thread > 1 and multiprocessing is not None:
                    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    q_out = multiprocessing.Queue(1024)
                    # define the process
                    read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
                                    for i in range(args.num_thread)]
                    # process images with num_thread process
                    for p in read_process:
                        p.start()
                    # only use one process to write .rec to avoid race-condtion
                    write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                    write_process.start()
                    # put the image list into input queue
                    for i, item in enumerate(image_list):
                        q_in[i % len(q_in)].put((i, item))
                    for q in q_in:
                        q.put(None)
                    for p in read_process:
                        p.join()

                    q_out.put(None)
                    write_process.join()
                else:
                    print('multiprocessing not available, fall back to single threaded encoding')
                    try:
                        import Queue as queue
                    except ImportError:
                        import queue
                    q_out = queue.Queue()
                    fname = os.path.basename(fname)
                    fname_rec = os.path.splitext(fname)[0] + '.rec'
                    fname_idx = os.path.splitext(fname)[0] + '.idx'
                    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                                           os.path.join(working_dir, fname_rec), 'w')
                    cnt = 0
                    pre_time = time.time()
                    for i, item in enumerate(image_list):
                        image_encode(args, i, item, q_out)
                        if q_out.empty():
                            continue
                        _, s, _ = q_out.get()
                        record.write_idx(item[0], s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            print('time:', cur_time - pre_time, ' count:', cnt)
                            pre_time = cur_time
                        cnt += 1
        if not count:
            print('Did not find and list file with prefix %s'%args.prefix)
View Code

相关文章:

  • 2022-01-01
  • 2021-04-23
  • 2021-05-25
  • 2021-05-31
  • 2021-08-27
  • 2022-12-23
  • 2022-12-23
猜你喜欢
  • 2021-11-02
  • 2021-10-29
  • 2021-12-08
  • 2021-09-07
  • 2021-10-19
  • 2021-07-03
相关资源
相似解决方案