【问题标题】:Unittest the pytorch forward function对pytorch forward函数进行单元测试
【发布时间】:2021-01-03 03:56:02
【问题描述】:

我想在 Pytorch 中对我的网络模型的重写前向功能进行单元测试。所以我用 setUp 方法加载了我的模型(从 Zoo 预训练),加载了一个种子并创建了一些随机批次。在我的方法 testForward 中,我测试了 forward 对 shape 和 numel 的结果,但我还想检查一个 aapears 为 0 的特定值。我对此并不担心,所以也检查了 setUp 中的参数,这似乎不是为 0。

import unittest
import torch
from SemanticSegmentation.models.fcn8 import FCN8


class TestFCN8(unittest.TestCase):

    def setUp(self):
        self.model = FCN8(8, pretrained=True)
        torch.manual_seed(0)
        self.x = torch.rand((4, 3, 45, 45))
        for param in self.model.parameters():
            print(param.data)

    def testForward(self):
        self.assertEqual(self.model.forward(self.x).shape.numel(), 64800)
        self.assertEqual(str(self.model.forward(self.x).shape), 'torch.Size([4, 8, 45, 45])')
        print(self.model.named_parameters)


if __name__ == "__main__":
    unittest.main()

所以我的问题是:前向返回张量的 sahpe 是我所期望的,但为什么这个张量完全为零?我预计至少有几个值。

导入的模型基于 VGG16 网络,在 ConvLayer 4、8 和 16 之后加分。如果需要,我还可以提供模型代码。

【问题讨论】:

  • 有什么问题?

标签: python pytorch vgg-net


【解决方案1】:

好的,在修补和调试转发功能后,我得到以下解释:

有关架构的一些信息

如果您学习 Andrew Ng 或其他人的课程,您将学会不要将权重初始化为相同的值,例如“0”。这就是 FCN 原始论文的作者所做的和他们所说的,因为它不会改变性能或不会产生更快的收敛 (FCN-Paper)。

我的解决方案

因此,出于测试目的,我在测试模块中初始化以播种随机值,我可以对其进行测试:

import unittest
import torch
from SemanticSegmentation.models.fcn8 import FCN8


class TestFCN8(unittest.TestCase):

    def setUp(self):
        self.model = FCN8(8, pretrained=True)
        torch.manual_seed(0)
        # instead of zero init for score tensors use random init
        self.model.score_fr[6].weight.data.random_()
        self.model.score_fr[6].bias.data.random_()
        self.model.score_pool3.weight.data.random_()
        self.model.score_pool3.bias.data.random_()
        self.model.score_pool4.weight.data.random_()
        self.model.score_pool4.bias.data.random_()
        self.x = torch.rand((4, 3, 45, 45))

    def testForward(self):
        self.assertEqual(
            self.model.forward(self.x).shape.numel(), 64800)
        self.assertEqual(
            list(self.model.forward(self.x).shape), [4, 8, 45, 45])
        self.assertEqual(
            float(self.model.forward(self.x)[3][4][44][4]), 2277257216.0))

if __name__ == "__main__":
    unittest.main()

【讨论】:

    猜你喜欢
    • 2013-05-22
    • 1970-01-01
    • 2019-05-16
    • 2023-03-26
    • 2014-02-08
    • 1970-01-01
    • 2019-04-15
    • 2018-01-15
    • 2019-05-19
    相关资源
    最近更新 更多