【Question title】: Why do two exactly identical Keras models predict different results for the same input data in the same environment?
【Posted】: 2021-01-27 00:14:45
【Question description】:

I have two models that I have verified to be identical, as shown below:

# Compare the two models layer by layer, weight array by weight array.
if len(m_s.layers) != len(m_m.layers):
    print("number of layers are different")

for i in range(len(m_s.layers)):
    weight_s = m_s.layers[i].get_weights()
    weight_m = m_m.layers[i].get_weights()

    if len(weight_s) > 0:
        for j in range(len(weight_s)):
            # As posted: `.all` is referenced here but not called.
            if (weight_s[j] == weight_m[j]).all:
                print("layer %d identical" % i)
            else:
                print("!!!!! layer %d not the same" % i)
    else:
        # A layer without weights matches only if the other has none either.
        if len(weight_m) == 0:
            print("layer %d identical" % i)
        else:
            print("!!!!! layer %d not the same" % i)

The output shows that they are identical. Both models are slices of an ImageNet model.

layer 0 identical
layer 1 identical
layer 2 identical
layer 2 identical
layer 2 identical
layer 2 identical
layer 3 identical
layer 4 identical
layer 5 identical
layer 5 identical
layer 5 identical
layer 5 identical
layer 6 identical
layer 7 identical
layer 8 identical
layer 8 identical
layer 8 identical
layer 8 identical
layer 9 identical
layer 10 identical
layer 10 identical
layer 10 identical
layer 10 identical
layer 11 identical
layer 12 identical
layer 13 identical
layer 13 identical
layer 13 identical
layer 13 identical
layer 14 identical
layer 15 identical
layer 16 identical
layer 16 identical
layer 16 identical
layer 16 identical
layer 17 identical
layer 18 identical
layer 18 identical
layer 18 identical
layer 18 identical
layer 19 identical
layer 20 identical
layer 21 identical
layer 21 identical
layer 21 identical
layer 21 identical
layer 22 identical
layer 23 identical
layer 24 identical
layer 24 identical
layer 24 identical
layer 24 identical
layer 25 identical
layer 26 identical
layer 27 identical
layer 27 identical
layer 27 identical
layer 27 identical
layer 28 identical
layer 29 identical
layer 30 identical
layer 30 identical
layer 30 identical
layer 30 identical
layer 31 identical
layer 32 identical
layer 33 identical
layer 33 identical
layer 33 identical
layer 33 identical
layer 34 identical
layer 35 identical
layer 35 identical
layer 35 identical
layer 35 identical
layer 36 identical
layer 37 identical
layer 38 identical
layer 38 identical
layer 38 identical
layer 38 identical
layer 39 identical
layer 40 identical
layer 41 identical
layer 41 identical
layer 41 identical
layer 41 identical
layer 42 identical
layer 43 identical
layer 44 identical
layer 44 identical
layer 44 identical
layer 44 identical
layer 45 identical
layer 46 identical
layer 47 identical
layer 47 identical
layer 47 identical
layer 47 identical
layer 48 identical
layer 49 identical
layer 50 identical
layer 50 identical
layer 50 identical
layer 50 identical
layer 51 identical
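
One caveat about a check of this shape: `(a == b).all` written without call parentheses evaluates to a bound method, which is always truthy, so the `identical` branch is taken for every weight array no matter what the arrays contain. Below is a minimal sketch of a stricter comparison, assuming the per-layer weights are plain NumPy arrays as returned by `get_weights()`. The helper name `diff_weight_lists` is hypothetical and not part of the original post.

```python
import numpy as np

def diff_weight_lists(weights_a, weights_b):
    """Return indices of weight arrays that differ between two layers.

    Both arguments are lists of NumPy arrays, e.g. the result of
    layer.get_weights() on two corresponding Keras layers.
    """
    if len(weights_a) != len(weights_b):
        # Different number of arrays: report every index of the longer list.
        return list(range(max(len(weights_a), len(weights_b))))
    # np.array_equal returns a plain bool and is False on a shape mismatch.
    return [j for j, (a, b) in enumerate(zip(weights_a, weights_b))
            if not np.array_equal(a, b)]

# A bare `.all` is a method object, so it is truthy even for unequal arrays.
a = np.array([1.0, 2.0])
b = np.array([1.0, 3.0])
assert bool((a == b).all) is True      # method object, always truthy
assert bool((a == b).all()) is False   # the actual element-wise result
```

For the two models in the question, this could be called as `diff_weight_lists(m_s.layers[i].get_weights(), m_m.layers[i].get_weights())` for each layer index `i`; an empty list means every weight array of that layer matches exactly.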

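However, when I use the two models to predict on the same input data, on the same machine and in the same environment, the outputs are completely different.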

m_s.predict(data)

outputs

array([[[[-2.2014694e+00, -7.4636793e+00, -3.7543521e+00, ...,
           4.2393379e+00,  7.2923303e+00, -7.9203067e+00],
         [-6.8980045e+00, -6.7517347e+00,  5.9752476e-01, ...,
           2.2391853e+00, -2.0161586e+00, -7.5054851e+00],
         [-4.4470978e+00, -4.2420959e+00, -3.9374633e+00, ...,
           5.9843721e+00,  5.4481273e+00, -2.7136576e+00],
         ...,
         [-8.2077494e+00, -5.5874801e+00,  2.2708473e+00, ...,
          -2.5585687e-01,  4.0198727e+00, -4.5880938e+00],
         [-7.5793233e+00, -6.3811040e+00,  3.7389126e+00, ...,
           1.7169635e+00, -3.4249902e-01, -7.1873198e+00],
         [-8.2512989e+00, -4.2883468e+00, -2.7908459e+00, ...,
           3.9796615e+00,  4.7512245e-01, -4.5338011e+00]],
        [[-5.2522459e+00, -5.2272692e+00, -3.7313356e+00, ...,
           1.0820831e+00, -1.9317195e+00, -8.3177958e+00],
         [-5.8229809e+00, -6.8049965e+00, -1.4538713e+00, ...,
           4.0576010e+00, -1.9025326e-02, -8.2517090e+00],
         [-6.1541910e+00, -2.6757658e-01, -5.4412403e+00, ...,
           1.7984511e+00,  2.9016986e+00,  7.6427579e-01],
         ...,
         [-1.1129386e+00,  7.9319181e+00,  7.7404571e-01, ...,
          -1.7145084e+01,  1.5210888e+01,  1.3812095e+01],
         [ 3.5752565e-01,  1.4212518e+00, -6.1826277e-01, ...,
          -3.4348285e+00,  5.1942883e+00,  2.1960042e+00],
         [-6.3907943e+00, -5.3237562e+00, -3.1632636e+00, ...,
           2.1118989e+00, -3.8516359e+00, -6.2463970e+00]],
        [[-7.2064867e+00, -3.6420932e+00, -1.6844990e+00, ...,
           6.4910537e-01, -4.4807429e+00, -7.8619242e+00],
         [-6.4934230e+00, -4.5477719e+00,  9.2149705e-01, ...,
           4.2846882e-01, -7.4903011e-01, -9.8737726e+00],
         [-7.2704558e+00,  9.5214283e-01, -2.0818310e+00, ...,
          -1.6958854e-01,  1.6371614e+00, -2.7756066e+00],
         ...,
         [-7.1980424e+00, -7.2074276e-01,  2.3514495e+00, ...,
          -9.7255888e+00,  2.1547556e-01,  4.3379207e+00],
         [-6.7656651e+00,  6.3100419e+00, -7.8286257e+00, ...,
          -5.1035576e+00, -1.3960669e+00,  2.3991609e+00],
         [-7.0669832e+00, -1.2582588e-01, -5.3176193e+00, ...,
           3.4836166e+00, -2.4024684e+00, -6.0632706e+00]],
        ...,
        [[-7.3400059e+00, -3.1168675e+00, -1.9545169e+00, ...,
           1.0936095e+00, -1.5736668e+00, -9.5641651e+00],
         [-2.9115820e+00, -4.7334772e-01,  2.6805878e-01, ...,
           8.3148491e-01, -1.2751791e+00, -5.5142212e+00],
         [ 1.2365078e+00,  1.0945862e+01, -4.9259267e+00, ...,
           1.9169430e+00,  5.1151342e+00,  4.9710069e+00],
         ...,
         [-2.2321188e+00,  8.8735223e-02, -7.6890874e+00, ...,
          -3.1269640e-01,  7.3404179e+00, -7.2507386e+00],
         [-2.2741010e+00, -6.5992510e-01,  4.0761769e-01, ...,
           1.8645943e+00,  4.0359187e+00, -7.7996893e+00],
         [ 5.5672646e-02, -1.4715804e+00, -1.9753509e+00, ...,
           2.5039923e+00, -1.0506821e-01, -6.5183282e+00]],

        [[-8.3111782e+00, -4.6992331e+00, -3.1351955e+00, ...,
           1.8569698e+00, -1.1717710e+00, -8.5070782e+00],
         [-4.7671299e+00, -2.5072317e+00,  2.9760203e+00, ...,
           2.9142296e+00,  3.2271760e+00, -4.7557964e+00],
         [ 5.5070686e-01,  5.3218126e-02, -2.1629403e+00, ...,
           8.8359457e-01,  3.1481497e+00, -2.1769693e+00],
         ...,
         [-3.7305963e+00, -1.2512873e+00,  2.0231385e+00, ...,
           4.4094267e+00,  3.0268743e+00, -9.6763916e+00],
         [-5.4271636e+00, -4.6796727e+00,  5.7922940e+00, ...,
           3.6725988e+00,  5.2563481e+00, -8.1707211e+00],
         [-1.2138665e-02, -3.6983132e+00, -6.4367266e+00, ...,
           6.8217549e+00,  5.7782011e+00, -5.4132147e+00]],

        [[-5.0323372e+00, -3.3903065e+00, -2.7963824e+00, ...,
           3.9016938e+00,  1.4906535e+00, -2.1907964e+00],
         [-7.7795396e+00, -5.7441168e+00,  3.4615259e+00, ...,
           1.4764800e+00, -2.9045539e+00, -4.4136987e+00],
         [-7.2599754e+00, -3.4636111e+00,  4.3936129e+00, ...,
           1.9856967e+00, -1.0856767e+00, -5.7980385e+00],
         ...,
         [-6.1726952e+00, -3.9608026e+00,  5.5742388e+00, ...,
           4.9396091e+00, -2.8744078e+00, -8.3122082e+00],
         [-1.3442982e+00, -5.5807371e+00,  4.7524319e+00, ...,
           5.0170369e+00,  2.9530718e+00, -7.1846304e+00],
         [-1.7616816e+00, -6.7234058e+00, -8.3512306e+00, ...,
           4.1365266e+00, -2.8818092e+00, -2.9208889e+00]]]],
      dtype=float32)

while

m_m.predict(data)

outputs

array([[[[ -7.836284  ,  -2.3029385 ,  -3.6463926 , ...,  -1.104739  ,
           12.992413  ,  -6.7326055 ],
         [-11.714638  ,  -2.161682  ,  -2.0715065 , ...,  -0.0467519 ,
            6.557784  ,  -2.7576606 ],
         [ -8.029486  ,  -4.068902  ,  -4.6803293 , ...,   7.022674  ,
            7.741771  ,  -1.874607  ],
         ...,
         [-11.229774  ,  -5.3050747 ,   2.807798  , ...,   1.1340691 ,
            4.3236184 ,  -5.2162905 ],
         [-11.458603  ,  -6.2387724 ,   0.25091058, ...,   1.0305461 ,
            5.9631624 ,  -6.284294  ],
         [ -8.663513  ,  -1.8256164 ,  -3.0079443 , ...,   5.9437366 ,
            7.0928698 ,  -1.0781381 ]],

        [[ -4.362539  ,  -2.8450599 ,  -3.1030283 , ...,  -1.5129573 ,
            2.2504683 ,  -8.414198  ],
         [ -6.308961  ,  -4.99597   ,  -3.8596241 , ...,   4.2793174 ,
            2.7787375 ,  -5.9963284 ],
         [ -4.8252788 ,  -1.5710263 ,  -6.083002  , ...,   4.856139  ,
            2.9387665 ,   0.29977918],
         ...,
         [ -0.8481703 ,   5.348722  ,   2.3885899 , ..., -19.35567   ,
           13.1428795 ,  12.364189  ],
         [ -1.8864173 ,  -3.7014763 ,  -2.5292692 , ...,  -3.6618025 ,
            4.3906307 ,   0.03934002],
         [ -6.0526505 ,  -5.504422  ,  -3.8778243 , ...,   4.3741727 ,
            1.0135782 ,  -5.1025114 ]],

        [[ -6.7328253 ,  -1.5671132 ,   0.16782492, ...,  -2.5069456 ,
            1.4343324 ,  -8.59162   ],
         [ -7.5468965 ,  -5.6893063 ,   0.13871288, ...,   0.22174302,
            1.1608338 ,  -8.77916   ],
         [ -5.940791  ,   1.1769392 ,  -4.5080614 , ...,   3.5371704 ,
            2.4181929 ,  -2.7893126 ],
         ...,
         [ -9.490874  ,  -2.3575358 ,   2.5908213 , ..., -18.813345  ,
           -3.4546187 ,   4.8375816 ],
         [ -5.1123285 ,   3.3766522 , -10.71935   , ...,  -5.8476105 ,
           -3.5569503 ,   0.6331433 ],
         [ -6.2075157 ,   0.4942119 ,  -7.044799  , ...,   5.191918  ,
            2.7723277 ,  -4.5243273 ]],

        ...,

        [[ -7.06453   ,  -1.3950944 ,  -0.37429178, ...,  -0.11883163,
            0.22527158,  -9.231563  ],
         [ -4.0204725 ,  -3.6592636 ,   0.15709507, ...,   1.7647433 ,
            4.6479545 ,  -3.8798246 ],
         [  0.75817275,   9.890637  ,  -7.069035  , ...,   2.995041  ,
            6.8453026 ,   6.028713  ],
         ...,
         [ -1.5892754 ,   2.119719  , -10.078391  , ...,  -2.546938  ,
            6.5255003 ,  -6.749384  ],
         [ -3.2769198 ,  -0.46709523,  -2.1529863 , ...,   1.8028917 ,
            7.2509494 ,  -7.5441256 ],
         [ -1.2531447 ,   0.96327865,  -1.0863694 , ...,   2.423694  ,
           -1.1047542 ,  -6.4944725 ]],

        [[-10.218704  ,  -2.5448627 ,  -0.6002845 , ...,   0.80485874,
            2.7691112 ,  -7.374723  ],
         [ -8.354421  ,  -5.461962  ,   5.2284613 , ...,   0.5315646 ,
            5.701563  ,  -4.0477304 ],
         [ -2.7866952 ,  -5.8492465 ,  -1.5627437 , ...,   1.9490132 ,
            4.0491743 ,  -2.7550128 ],
         ...,
         [ -4.5389686 ,  -3.2624135 ,   0.7429285 , ...,   2.5953412 ,
            3.8780956 ,  -8.652936  ],
         [ -5.704813  ,  -3.730238  ,   4.87866   , ...,   2.6826556 ,
            4.8833456 ,  -6.8225956 ],
         [ -0.16680491,  -0.4325713 ,  -4.7689047 , ...,   8.588567  ,
            6.786765  ,  -4.7118473 ]],

        [[ -1.4958351 ,   2.151188  ,  -4.1733856 , ...,  -1.891511  ,
           12.969635  ,  -2.5913832 ],
         [ -7.6865544 ,   0.5423928 ,   6.2699823 , ...,  -2.4558625 ,
            6.1929445 ,  -2.7875526 ],
         [ -6.995783  ,   2.609788  ,   5.6196365 , ...,  -0.6639404 ,
            5.7171726 ,  -3.7962272 ],
         ...,
         [ -3.6628227 ,  -1.3322173 ,   4.7582774 , ...,   2.122392  ,
            3.1294663 ,  -8.338194  ],
         [ -3.0116327 ,  -1.322252  ,   4.802135  , ...,   1.9731755 ,
            8.750839  ,  -6.989321  ],
         [  2.3386476 ,  -2.4584374 ,  -5.9336634 , ...,   0.48920852,
            3.540884  ,  -2.9136944 ]]]], dtype=float32)

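This is clearly not due to floating-point rounding, because the outputs are completely different. I don't understand why. Please help.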
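
To put a number on "completely different", the two prediction arrays can be compared directly. This is a minimal sketch, assuming both outputs are NumPy arrays of the same shape; the helper name `summarize_difference` and the tolerance are illustrative choices, not part of the original post.

```python
import numpy as np

def summarize_difference(out_a, out_b, atol=1e-5):
    """Summarize how far apart two prediction arrays are."""
    out_a = np.asarray(out_a, dtype=np.float64)
    out_b = np.asarray(out_b, dtype=np.float64)
    abs_diff = np.abs(out_a - out_b)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        # Fraction of elements that differ by more than the tolerance.
        "fraction_different": float((abs_diff > atol).mean()),
        "allclose": bool(np.allclose(out_a, out_b, atol=atol)),
    }

# The first two values of each output printed above.
print(summarize_difference([-2.2014694, -7.4636793], [-7.836284, -2.3029385]))
```

Rounding noise in float32 would typically show up as a very small `max_abs_diff`; differences of several units, as in the arrays above, point to the models computing something different.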

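【Discussion】:

  • Perhaps the weights have not been built yet. Can you run the same weight check after calling the predict method?
  • @TouYou Thanks for the comments. I checked the weights after calling predict, and the two models are identical. In fact, both models were loaded from trained models, so I believe the weights should already have been built.
  • Are the two models copies of each other? Do they use the same activation functions?
  • Yes, they are copies of each other. I just found the cause: there are BN layers, and the weights of the BN layers changed even though I had set them to non-trainable.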
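
The last comment is consistent with how batch normalization behaves at inference time: the layer normalizes with its stored moving mean and variance, which are non-trainable weights kept alongside `gamma` and `beta` (with default settings, Keras `BatchNormalization.get_weights()` returns `[gamma, beta, moving_mean, moving_variance]`). The thread does not say how the statistics came to differ. The following is a minimal NumPy sketch, with made-up numbers, of how two layers that share `gamma` and `beta` still produce different outputs when their moving statistics differ:

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Batch normalization in inference mode, applied over the last axis.

    eps=1e-3 matches the default epsilon of the Keras layer.
    """
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

x = np.array([[1.0, 2.0, 3.0]])
gamma = np.ones(3)
beta = np.zeros(3)

# Same gamma and beta, different stored statistics (illustrative values).
out_s = batchnorm_inference(x, gamma, beta,
                            moving_mean=np.zeros(3), moving_var=np.ones(3))
out_m = batchnorm_inference(x, gamma, beta,
                            moving_mean=np.full(3, 0.5), moving_var=np.full(3, 4.0))

print(np.abs(out_s - out_m).max())
```

Because every later layer consumes these outputs, even a modest shift in the stored statistics can change the final prediction substantially.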

Tags: tensorflow keras imagenet


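【Solution 1】:

I found the cause by extracting layer by layer. The model contains BatchNormalization layers, and their weights had changed even though I had set them to non-trainable.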
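
The layer-by-layer approach can be sketched as follows: feed the same input through both stacks one layer at a time and stop at the first layer whose outputs disagree. This is a simplified illustration that treats each model as a plain sequential list of callables (hypothetical stand-ins written in NumPy, not the asker's actual code). A real ImageNet slice may contain branches, in which case one would instead compare the intermediate outputs of sub-models built up to each layer.

```python
import numpy as np

def first_divergent_layer(layers_a, layers_b, x, atol=1e-6):
    """Return the index of the first layer whose outputs differ, or None.

    layers_a and layers_b are equal-length lists of callables that map an
    array to an array; x is the shared input.
    """
    out_a, out_b = x, x
    for i, (layer_a, layer_b) in enumerate(zip(layers_a, layers_b)):
        out_a = layer_a(out_a)
        out_b = layer_b(out_b)
        if not np.allclose(out_a, out_b, atol=atol):
            return i
    return None

# Stand-in "layers": a scale, a normalization with stored stats, a ReLU.
def scale(x):
    return 2.0 * x

def make_norm(mean, var):
    return lambda x: (x - mean) / np.sqrt(var + 1e-3)

def relu(x):
    return np.maximum(x, 0.0)

model_s = [scale, make_norm(0.0, 1.0), relu]
model_m = [scale, make_norm(0.5, 4.0), relu]  # only the stored stats differ

x = np.array([1.0, -2.0, 3.0])
print(first_divergent_layer(model_s, model_m, x))  # prints 1, the norm layer
```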
