【Question title】: How to calculate output of Keras UpSampling2D Layer?
【Posted】: 2020-06-26 07:21:21
【Question description】:

I don't understand how the output of the UpSampling2D layer in Keras is calculated. Here is an example:

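import numpy as np
from tensorflow.keras.layers import Input, UpSampling2D
from tensorflow.keras.models import Model

img_input = Input((2, 2, 1))
out = UpSampling2D(size=2, interpolation="bilinear")(img_input)
model = Model(img_input, out, name='test')

input = np.array([[100, 200], [6, 8]], dtype="float32").reshape(1, 2, 2, 1)
model.predict(input).reshape(4, 4)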

This gives:

array([[100. , 150. , 200. , 200. ],
       [ 53. ,  78.5, 104. , 104. ],
       [  6. ,   7. ,   8. ,   8. ],
       [  6. ,   7. ,   8. ,   8. ]], dtype=float32)

I would expect bilinear interpolation to give something different. Take the 150 in the first row: I think it should be 100*(2/3) + 200*(1/3) = 133.33. What does this layer do differently?
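The 133.33 expected here is what you get when the first and last output pixels are pinned to the first and last input pixels, so output index `i` samples input coordinate `i * (in - 1) / (out - 1)`. Older TensorFlow resize functions expose this convention as `align_corners=True`. A minimal 1-D NumPy sketch of that mapping (the helper name `linear_resize_corners` is hypothetical):

```python
import numpy as np


def linear_resize_corners(row, out_len):
    """1-D linear interpolation with the end points of input and output aligned.

    Output index i samples input coordinate i * (len(row) - 1) / (out_len - 1).
    """
    row = np.asarray(row, dtype=np.float64)
    coords = np.arange(out_len) * (len(row) - 1) / (out_len - 1)
    # np.interp evaluates the piecewise-linear function through (index, value)
    return np.interp(coords, np.arange(len(row)), row)


r = linear_resize_corners([100, 200], 4)
print(r)  # approximately [100, 133.33, 166.67, 200]
```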

Thanks!

【Question comments】:

    Tags: image tensorflow machine-learning keras conv-neural-network


    【Solution 1】:

    For your input array, these are the steps of bilinear upsampling with size=2:

    # input array
    [[100, 200],
     [  6,   8]]
    
    # Double the size and place the existing values at every other position.
    #  Important! The last row and column receive no input value of their own:
    [[100,   _, 200,   _],
     [  _,   _,   _,   _],
     [  6,   _,   8,   _],
     [  _,   _,   _,   _]]
    
    # Fill the empty positions in passes, using this rule:
    #   an empty position gets the arithmetic mean of its horizontal and
    #   vertical neighbors that were already filled before the pass.
    # Two passes fill the whole array:
    # 1. After the first pass:
    [[100, 150, 200, 200],
     [ 53,   _, 104,   _],
     [  6,   7,   8,   8],
     [  6,   _,   8,   _]]
    # 2. After the second pass:
    [[100,  150, 200, 200],
     [ 53, 78.5, 104, 104],
     [  6,    7,   8,   8],
     [  6,    7,   8,   8]]
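The numbers above (and in the question) are also consistent with a more direct description: output index `i` samples input coordinate `i / size` (here 0, 0.5, 1, 1.5), interpolates linearly between the two nearest input pixels, and clamps coordinates past the last input pixel to the edge, which is why the last row and column simply repeat the border. This matches the legacy TensorFlow resize convention without a half-pixel offset; whether your Keras/TensorFlow version uses it is an assumption worth checking. A NumPy-only sketch (the helper `bilinear_upsample` and its `half_pixel` flag are hypothetical, not a Keras API):

```python
import numpy as np


def bilinear_upsample(img, size, half_pixel=False):
    """Bilinear upsampling of a 2-D array by an integer factor `size` (sketch).

    half_pixel=False: output index i samples input coordinate i / size.
    half_pixel=True:  output index i samples (i + 0.5) / size - 0.5.
    Coordinates outside the input are clamped to the border pixels.
    """
    img = np.asarray(img, dtype=np.float64)

    def sample_points(n_in):
        i = np.arange(n_in * size)
        c = (i + 0.5) / size - 0.5 if half_pixel else i / size
        c = np.clip(c, 0, n_in - 1)          # clamp to the edge
        lo = np.floor(c).astype(int)         # left/top neighbor
        hi = np.minimum(lo + 1, n_in - 1)    # right/bottom neighbor
        return lo, hi, c - lo                # weight of the `hi` neighbor

    r0, r1, wr = sample_points(img.shape[0])
    c0, c1, wc = sample_points(img.shape[1])
    # Interpolate horizontally in the two neighboring rows, then vertically
    top = img[r0][:, c0] * (1 - wc) + img[r0][:, c1] * wc
    bottom = img[r1][:, c0] * (1 - wc) + img[r1][:, c1] * wc
    return top * (1 - wr)[:, None] + bottom * wr[:, None]


result = bilinear_upsample([[100, 200], [6, 8]], size=2)
print(result)
# rows: [100, 150, 200, 200], [53, 78.5, 104, 104], [6, 7, 8, 8], [6, 7, 8, 8]
```

Passing `half_pixel=True` instead gives a first row of [100, 125, 175, 200]. The half-pixel mapping is the convention of `tf.image.resize` in TensorFlow 2.x, so the exact output of `UpSampling2D` may depend on the version in use.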
    

    If you want [100, 133.33, 166.67, 200] in the first row (with the rest of the array filled in accordingly), upsample with size=3 and then crop away the border (res[1:5, 1:5]):

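    import numpy as np
    from tensorflow.keras.layers import Input, UpSampling2D
    from tensorflow.keras.models import Model

    img_input = Input((2, 2, 1))
    # Upsample 2x2 -> 6x6, then keep only the inner 4x4 block
    out = UpSampling2D(size=3, interpolation="bilinear")(img_input)
    model = Model(img_input, out, name='test')
    
    input = np.array([[100, 200], [6, 8]], dtype="float32").reshape(1, 2, 2, 1)
    model.predict(input).reshape(6, 6)[1:5, 1:5]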
    >> array([[100.       , 133.33334  , 166.66667  , 200.       ],
           [ 68.666664 ,  91.111115 , 113.55555  , 136.       ],
           [ 37.333324 ,  48.888878 ,  60.444427 ,  71.999985 ],
           [  6.       ,   6.666667 ,   7.3333335,   8.       ]],
          dtype=float32)
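The printed 4x4 result is reproduced by bilinear sampling with a half-pixel offset: output index `i` samples input coordinate `(i + 0.5) / size - 0.5`, which for size=3 and indices 1 to 4 gives exactly 0, 1/3, 2/3, 1. Under the plain `i / size` mapping that reproduces the question's size=2 output, this crop would not give these numbers, so the trick appears to rely on the half-pixel convention and is worth checking against your own TensorFlow version. The NumPy sketch below (hypothetical helpers `resize_axis` and `upsample_half_pixel`) checks that cropping a size-3 half-pixel upsample equals corner-aligned interpolation straight to 4x4:

```python
import numpy as np


def resize_axis(a, coords, axis):
    """Linearly interpolate array `a` at fractional `coords` along `axis`.

    np.interp clamps coordinates outside [0, n - 1] to the edge values.
    """
    grid = np.arange(a.shape[axis])
    return np.apply_along_axis(lambda v: np.interp(coords, grid, v), axis, a)


def upsample_half_pixel(img, size):
    """Bilinear upsampling where output index i samples (i + 0.5) / size - 0.5."""
    out = np.asarray(img, dtype=np.float64)
    for axis in (0, 1):  # bilinear = linear along rows, then along columns
        i = np.arange(out.shape[axis] * size)
        out = resize_axis(out, (i + 0.5) / size - 0.5, axis)
    return out


img = [[100, 200], [6, 8]]
cropped = upsample_half_pixel(img, 3)[1:5, 1:5]

# Corner-aligned interpolation straight to 4x4: index j samples j * (2 - 1) / (4 - 1)
corners = np.asarray(img, dtype=np.float64)
for axis in (0, 1):
    corners = resize_axis(corners, np.arange(4) / 3, axis)

print(np.allclose(cropped, corners))  # True
print(cropped[0])  # approximately [100, 133.33, 166.67, 200]
```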
    

    【Discussion】: