这是一篇来自Google DeepMind 2015年发表的一篇文章,由于最近在看的自然场景中的Curved text需要将Curved text 通过STN变换成tightly-bounded的文本,这样变换之后的文本对于后续的文本识别的效果更好。

    众所周知,CNN(卷积神经网络)是一个很强大的分类网络,但是对于input data 还是缺少spatially invariant的能力,通常来说,要使CNN具有spatially invariant的能力,一般都是对Training data 做一个data argument 的操作。这里作者通过在传统的CNN模型中加入了空间变换网络,就无需做data argument 的过程,这样Training data会大大的减少,同时该网络也具有invariance to translation, scale, rotation and more generic warping等性质,因为该网络会自动将输入的各种类型的目标(例如curved text自动转换成tightly-bounded的文本),或者对于手写字的图片,会做如下所示的变换。

 Spatial Transformer Networks 空间变换网络   

图1

图1所示,对于一个distorted Mnist  digit classification问题,将STN 放入到这个fully-connected network的第一层会出现的效果如上所示,(a)图是输入图片(注意STN网络不只是针对image,还有feature map,这里是对image进行处理)可以看出这个些图片是distorted with random translation, scale, rotation, and clutter,(b)是STN中的网络预测的transformation(也就是预测的参数),(c)是运用了空间变换之后的结果图片,(d)是分类结果。

    下面来看下具体的STN(空间变换网络)的具体结构。如下图所示。

Spatial Transformer Networks 空间变换网络

图2

U代表feature map ,V代表经过空间变换之后的feature map ,U 输入到STN中,第一步:STN中的localisation network 预测出transformation parameters θ(变换参数),第二步:Tθ(G)中的G是regular spatial grid(V中feature map中的一个二维坐标),Tθ()代表某种变换,Tθ(G)是代表的U中的sampling grid(是input map 中的点,output map中的各个点的像素值是通过这些点采样到的),第三步:以Sampling grid 和U 作为输入,输入到Sampler中得到结果V

1.Localisation Network

输入的feature map 的维度为Spatial Transformer Networks 空间变换网络,H代表feature map 的Height,W代表Width,C代表Channel(对于下面讨论的都是C=1的情况,也就是输入单个Channel feature map,输出一个 变换之后的feature map,对于C不为1的情况同样讨论)输出的结果为θSpatial Transformer Networks 空间变换网络,其中U为输入特征图,f loc()代表Localisation network 。

2.Parameterised Sampling Grid

这部分是产生Simpling grid的过程,前面提到Spatial Transformer Networks 空间变换网络代表的是output map 中的所有点的集合,其中Spatial Transformer Networks 空间变换网络代表的是每个点的坐标(其实STN的目的就是得到这些点上的像素值),这里考虑的是,对于如何产生input map 上的坐标,如下所示:

Spatial Transformer Networks 空间变换网络

图3

其中Spatial Transformer Networks 空间变换网络代表的是U上的坐标,Spatial Transformer Networks 空间变换网络代表的是2D 仿射变换(Spatial Transformer Networks 空间变换网络代表的是某种Transformation 操作,这里用的是2D仿射变换),Spatial Transformer Networks 空间变换网络代表的是V上的坐标(这里做的就是通过Spatial Transformer Networks 空间变换网络找到U上的坐标点Spatial Transformer Networks 空间变换网络),这些Spatial Transformer Networks 空间变换网络所组成的集合叫做Sampling Grid,这就得到了我们想要的结果(这时只是V的坐标已知,而V里面的像素值并不知道)下面介绍如何产生V中的像素值。

Spatial Transformer Networks 空间变换网络

图4

上图可以清楚的说明不同的变换参数(Localisation network获得的)得到的不同的结果,(a)中是经过参数变换之后获得的结果,其中U中的Sampling grid (蓝色的点)是通过变换得来的(I *Spatial Transformer Networks 空间变换网络这里Sampling grid 与V中的Regular grid是一样的,等于没有做变换),然后通过Sampler(第三部分介绍) 得到结果V。(b)中的结果是通过仿射变换得到结果,其中U中的Sampling grid 是通过(Spatial Transformer Networks 空间变换网络*Spatial Transformer Networks 空间变换网络)得到的,与(a)类似,经过Sampler 得到最后的结果。

3.Differentiable Image Sampling

      要实现Sampling 的功能,也就是将input map 进行转换得到output map 的结果,需要以第二部分得到的Sampling grid 和input map 作为输入,这样就能通过sampling kernel对input map中的像素进行采样,将得到的像素值放入对应的output map中相应的坐标位置。

Spatial Transformer Networks 空间变换网络

上式中Spatial Transformer Networks 空间变换网络代表在某个channel(这里讨论C=1的情况)的feature map(output map)中某个位置i的像素值,Spatial Transformer Networks 空间变换网络代表Sampling kernel ,它可以通过插值(例如双线性插值实现),Spatial Transformer Networks 空间变换网络Spatial Transformer Networks 空间变换网络代表Sampling kernel Spatial Transformer Networks 空间变换网络的参数。Spatial Transformer Networks 空间变换网络代表的input  map中坐标(n,m)位置的像素值,在上述式子(3)中Spatial Transformer Networks 空间变换网络是已知的,Spatial Transformer Networks 空间变换网络也是已知的,因此其意义就是找到与Spatial Transformer Networks 空间变换网络一定范围的点的像素值进行加和,从而得到Spatial Transformer Networks 空间变换网络,一定范围是由Spatial Transformer Networks 空间变换网络决定的,例如,如下方程式(4)所示:

Spatial Transformer Networks 空间变换网络

这里Spatial Transformer Networks 空间变换网络用了Kronecker delta 函数,目的就是取得离Spatial Transformer Networks 空间变换网络最近的位置为(n,m)的像素值Spatial Transformer Networks 空间变换网络,将其放入对应的output map的Spatial Transformer Networks 空间变换网络,此时这个像素值Spatial Transformer Networks 空间变换网络=Spatial Transformer Networks 空间变换网络,当然作者还介绍了另外一种Sampling kernel,就是双线性插值函数,如下(5)所示:

Spatial Transformer Networks 空间变换网络

该方程的意思是获得离Spatial Transformer Networks 空间变换网络距离不超过1的m,和离Spatial Transformer Networks 空间变换网络距离不超过1的n的所有满足条件的像素值Spatial Transformer Networks 空间变换网络的加和。最后作者为了说明STN是可以通过backpropagation来训练的,给出了如下公式:

Spatial Transformer Networks 空间变换网络

由于Sampling function 的不连续,这里需要使用sub-gradients更新参数。最后附上一张比较直观的图,如下所示:

Spatial Transformer Networks 空间变换网络

图5

总结:STN 能够在没有标注关键点的情况下,根据任务自己学习图片或特征的空间变换参数,将输入图片或者学习的特征在空间上进行对齐,从而减少物体由于空间中的旋转、平移、尺度、扭曲等几何变换对分类、定位等任务的影响。加入到已有的CNN或者FCN网络,能够提升网络的学习能力【4】

这里有个问题需要说明下:

1.为什么在进行仿射变换的时候不是通过Input map上的坐标来生成Output map上的坐标?而是通过Output map上的点来生成 Input map上的坐标。

答:(1)这是由于通过Input  map 坐标仿射变换生成 Output map的时候(左图生成右图),得到的Output map的坐标值可能不是整点数,这时候Input对应的Output虽然有像素值,但这时,像素值应该放在哪个格子里呢?如下图红点所示,假设第二排左边的一个×不在格子里面,那我们得到的这个坐标(也就是有图的红点)也就没什么意义。(2)(右图生成左图)那么有人可能会问,那么这样的话,当你通过Output map的坐标得到Input map的坐标也可能会得到不是整点的坐标值,这个回答是肯定的,例如第三排左边的第一个×,在仿射变换完事之后得到Input map上的点坐标就不在整点上,因此这时候就可以通过对Input map上的该点做双线性插值,就可以得到Input map上该点坐标近似的像素值了,放入Output map相应的格子中即可

Spatial Transformer Networks 空间变换网络Spatial Transformer Networks 空间变换网络

参考博客如下:

如对仿射变换不是很熟的话建议看下这个博主的博客,个人觉得不错。

https://blog.csdn.net/sinat_34474705/article/details/75125520

1.https://blog.csdn.net/xbinworld/article/details/69049680

2.http://www.cnblogs.com/neopenx/p/4851806.html

3.原文链接   https://arxiv.org/pdf/1506.02025.pdf

4.https://blog.csdn.net/shaoxiaohu1/article/details/51809605

5.https://blog.csdn.net/sinat_34474705/article/details/75268248


相关文章:

  • 2021-11-02
  • 2021-11-06
  • 2021-06-20
  • 2021-12-12
  • 2021-08-07
猜你喜欢
  • 2021-09-09
  • 2021-11-27
  • 2021-12-18
  • 2021-11-05
  • 2021-09-27
相关资源
相似解决方案