【问题标题】:What is the correct way to use a PyTorch Module inside a PyTorch Function?在 PyTorch 函数中使用 PyTorch 模块的正确方法是什么?
【发布时间】:2020-10-24 17:22:56
【问题描述】:

我们有一个自定义的torch.autograd.Functionz(x, t),它以一种不适合直接自动微分的方式计算输出y,并计算了运算相对于其输入xt的雅可比行列式,所以我们可以实现backward方法。

但是,该操作涉及对神经网络进行多次内部调用,我们现在已将其实现为 torch.nn.Linear 对象的堆栈,包装在 net 中,torch.nn.Module。在数学上,这些是由t 参数化的。

有什么方法可以让net 本身成为zforward 方法的输入?然后,我们将从我们的 backward 返回上游梯度 Dy 和参数 Jacobia dydt_i 的产品列表,其中每个参数 tinet 的子级(除了 @987654340 @,虽然x是数据,不需要梯度累积)。

或者我们真的需要采用t(实际上是一个单独的t_i的列表),并在z.forward内部重构Linear中所有Linear层的动作net

【问题讨论】:

    标签: python neural-network pytorch automatic-differentiation


    【解决方案1】:

    我猜你可以创建一个自定义仿函数来继承 torch.autograd.Function 并使 forwardbackward 方法非静态(即删除 this example 中的 @staticmethod 以便 net 可以是一个属性你的仿函数。看起来像

    class MyFunctor(torch.nn.autograd.Function):
        def __init(net):
             self.net = net
        
         def forward(ctx, x, t):
             #store x and t in ctx in the way you find useful
             # not sure how t is involved here
             return self.net(x) 
    
         def backward(ctx, grad):
             # do your backward stuff
    
    net = nn.Sequential(nn.Linear(...), ...)
    z = MyFunctor(net)
    y = z(x, t)
    

    这将产生一个警告,表明您正在使用一种已弃用的传统方法来创建 autograd 函数(由于非静态方法),并且在反向传播后,您需要格外小心将 net 中的梯度归零。所以不是很方便,但我不知道有什么更好的方法来拥有一个有状态的 autograd 函数。

    【讨论】:

    • 不幸的是,这似乎不起作用。如果您只是像调用Module 一样尝试调用Function,则弃用会被提升为RuntimeError(至少在PyTorch 1.6 和1.7 中)。
    • 更相关的是,虽然我们没有将net 列为forward 的输入,但我们仍然需要grad 和所有雅可比行列式dz_dtheta 之间的VJP(其中每个thetanet的参数)。但是 PyTorch 只期望来自 backward 的两个输出,因为我们向 forward 提供了两个输入。
    • 情况也会变得更糟:即使您将签名写为forward(x, t, *thetas) 并因此期望backward 有2+len(thetas) VJP 输出,您实际上也需要使用 backward 中的 thetas 以获得正确的渐变。不像 TensorFlow,net 的参数不是全局变量,因为 PyTorch 是逐个运行定义的。这是我要在backward 内部解决的问题,但我认为我需要一个函数net(x, *thetas)Moduleforward 重构为thetas 的显式函数。
    • 那么我很抱歉,我没有任何更聪明的想法。您可能应该尝试在 pytorch 论坛上询问,开发人员可能有解决方案
    • 起初,我认为可能有a way 直接在Function 上调用forward 而不是__call__Module 方式)或apply(@987654357 @) 方式。但后来我意识到,如果没有 @staticmethod,PyTorch 会忽略我的 backward 实现,而只是进行正常的运算符重载。我目前的方法是使用Module.load_state_dict 来创建闭包,而不必完全丢弃Module
    猜你喜欢
    • 2021-07-26
    • 2020-12-28
    • 2020-05-08
    • 2019-12-29
    • 2019-10-11
    • 1970-01-01
    • 2021-07-08
    • 2012-01-19
    • 1970-01-01
    相关资源
    最近更新 更多