【问题标题】:pytorch nn.Module inferencepytorch nn.Module 推理
【发布时间】:2021-08-04 16:40:01
【问题描述】:

我打算学习 Pytorch。但是在这个阶段我想问一个问题,以便我可以理解我正在阅读的一些代码

当你有一个基类是nn.Module的类时说

class My_model(nn.Module)

应该如何在那里进行推理?

在我正在阅读的代码中说

tasks_output, other = my_model(data)

那不只是创建一个对象吗? (就像调用类构造函数)

在 pytorch 中,应该如何进行推理?

(当my_model设置为my_model.eval()时,作为参考我在说)

编辑:我很抱歉。我犯了将类和对象声明为一体的错误。我更正了代码

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    你混淆了__init____call__
    在您的示例中,my_model 是一个,因此调用

    my_model_instance = my_model(arguments)
    

    使用arguments 调用my_model.__init__。此调用的结果是变量my_model_instancemy_model 的新实例

    一旦您将 my_model 实例化为变量my_model_instance,您就可以在训练数据上评估模型:

    tasks_output, other = my_model_instance(data)
    

    “调用”(即在变量名后面加上括号)模型的实例导致python调用类的方法__call__
    对于从nn.Modules 派生的类,这将调用nn.Module__call__,它会执行一些pytorch 的操作,并最终调用forwardforward 方法的实现@。

    请参阅this detailed thread,了解python中__init____call__的区别。

    通常是方便的关注PEP8 Style Guide for Python Code

    类名通常应使用 CapWords 约定。

    函数名称应为小写,必要时用下划线分隔单词以提高可读性。 变量名遵循与函数名相同的约定。

    【讨论】:

    • 谢谢!这很有帮助!
    • 这是一个很好的回应!谢谢!
    【解决方案2】:

    你有例子:

    class My_model(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 4 * 4, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 4 * 4)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    # Call construtor of Class 
    my_model = My_model()
    

    区分类和对象很重要。 名称的类在 Python 中以大写字母开头。

    如你所见的构造函数,它不带数据/输入参数,只有函数 forward 有一个。

    之后,为了培训,你必须:

    1. 计算带有标签的模型误差的标准。
    2. 它必须有反向传播算法的优化器

    示例:

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    

    最后,你必须通过循环需要这个元素:

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    

    在这里,您有一次反向传播迭代。

    Pytorch documentation

    如果您想考虑反向传播中的推理,您可以阅读如何使用 pytorch 创建图层以及 pytorch 如何使用签名。

    张量使用 Autograph 进行反向传播。以Pytorch documentation为例

    import torch
    
    x = torch.ones(5)  # input tensor
    y = torch.zeros(3)  # expected output
    w = torch.randn(5, 3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    z = torch.matmul(x, w)+b
    loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
    loss.backward()
    print(w.grad)
    print(b.grad)
    

    结果给出了反向传播,其中交叉熵标准计算模型和标签的距离。张量 z 不是唯一的值矩阵,而是一个具有 w、b、x、y 的“记忆计算”的类。

    在该层中,梯度使用前向函数进行此计算,或者在必要时使用后向函数。

    最好的尊重

    【讨论】:

    • 非常感谢您的回答。在我的情况下,我不是在训练阶段,而只是在评估阶段(因此my_model.eval()。在我进行推理时的代码中,使用数据调用对象来进行推理,如my_model(data)我不明白
    【解决方案3】:

    PyTorch 中的模型是通过从 nn.Module 基类继承来定义的:

    class Model(nn.Module)
        pass
    

    然后您可以实现一个forward 方法作为推理代码。无论是用于训练还是评估,它都应该返回模型的输出。

    class Model(nn.Module)
        forward(self, x)
            return x**2
    

    一旦你有了它,你可以初始化一个新模型:

    model = Model()
    

    要使用新初始化的模型,您实际上不会直接调用forwardnn.Module 的底层结构使得您可以改为调用 __call__。它将处理对您的 forward 实现的调用。要使用它,您只需像调用函数一样调用对象:

    >>> model(2)
    4
    

    documentation page 中,您可以看到nn.Module.eval 会将模型设置为评估模式,这会影响特定层,例如批量标准化层和dropout。这些类型的层通常在训练时打开,在评估和测试时关闭。你可以把它当作

    model.eval()
    

    在进行模型评估和测试时,建议使用torch.no_grad 上下文管理器。这避免了必须保留用于梯度反向传播的激活。

    with torch.no_grad():
        out = model(x)
    

    或者作为函数/方法声明之上的装饰器:

    @torch.no_grad()
    validate():
        pass
    

    【讨论】:

    • 谢谢!我正在阅读代码,是的,定义了一个 forward 函数。所以model(2) 实际上是在调用__call__,它在调用forward?我理解正确吗?
    • 其实更准确地说:model(2) model.__call__(2) (doc for __call__)。 __call__ 方法由nn.Module 定义,将在对象上调用forward
    猜你喜欢
    • 2022-12-12
    • 1970-01-01
    • 2018-04-29
    • 2021-11-07
    • 2021-03-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-10-18
    相关资源
    最近更新 更多