【问题标题】:How does PyTorch module do the back propPyTorch 模块如何做 back prop
【发布时间】:2018-09-10 16:49:44
【问题描述】:

按照extending PyTorch - adding a module 上的说明进行操作时,我注意到在扩展Module 时,我们实际上并不需要实现后向功能。我们唯一需要做的就是在 forward 函数中应用 Function 实例,PyTorch 可以在执行 back prop 时自动调用 Function 实例中的 back 实例。这对我来说似乎很神奇,因为我们甚至没有注册我们使用的 Function 实例。我查看了源代码,但没有发现任何相关内容。任何人都可以指出所有这些实际发生的地方吗?

【问题讨论】:

    标签: python python-3.x metaprogramming pytorch


    【解决方案1】:

    不必实现backward() 是 PyTorch 或任何其他深度学习框架如此有价值的原因。事实上,实现backward() 应该只在您需要弄乱网络梯度的非常特殊的情况下完成(或者当您创建无法使用 PyTorch 的内置函数表达的自定义函数时)。

    PyTorch 使用计算图来计算后向梯度,该计算图会跟踪在前向传递期间已完成的操作。在Variable 上完成的任何操作都会隐式地在此处注册。然后就是从调用它的变量向后遍历图,并应用导数链式法则来计算梯度。

    PyTorch 的About 页面对图表及其一般工作方式进行了很好的可视化。如果您想了解更多详细信息,我还建议您在 Google 上查找计算图和 autograd 机制。

    编辑:所有这些发生的源代码将在 PyTorch 代码库的 C 部分中,其中实现了实际图形。经过一番挖掘,我找到了this

    /// Evaluates the function on the given inputs and returns the result of the
    /// function call.
    variable_list operator()(const variable_list& inputs) {
        profiler::RecordFunction rec(this);
        if (jit::tracer::isTracingVar(inputs)) {
            return traced_apply(inputs);
        }
        return apply(inputs);
    }
    

    因此,在每个函数中,PyTorch 首先检查其输入是否需要跟踪,然后按照 here 的实现执行 trace_apply()。您可以看到正在创建的节点并附加到图中:

    // Insert a CppOp in the trace.
    auto& graph = state->graph;
    std::vector<VariableFlags> var_flags;
    for(auto & input: inputs) {
        var_flags.push_back(VariableFlags::of(input));
    }
    auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
    // ...
    for (auto& input: inputs) {
        this_node->addInput(tracer::getValueTrace(state, input));
    }
    graph->appendNode(this_node);
    

    我在这里最好的猜测是每个 Function 对象在执行时都会注册自己及其输入(如果需要)。每个非函数调用(例如 variable.dot())都只是遵循相应的函数,所以这仍然适用。

    注意:我不参与 PyTorch 的开发,也绝不是其架构方面的专家。欢迎任何更正或补充。

    【讨论】:

    • 是的,但是如果我们要扩展 autograd,我们需要定义 backward 函数,我很困惑它是如何在变量实例中注册的。我猜这是一个元类技巧,但不确定他们在源代码中的哪个位置做到了。
    【解决方案2】:

    也许我不正确,但我有不同的看法。

    后向函数被定义并被前向函数调用。

    例如:

    #!/usr/bin/env python
    # encoding: utf-8
    
    ###############################################################
    # Parametrized example
    # --------------------
    #
    # This implements a layer with learnable weights.
    #
    # It implements the Cross-correlation with a learnable kernel.
    #
    # In deep learning literature, it’s confusingly referred to as
    # Convolution.
    #
    # The backward computes the gradients wrt the input and gradients wrt the
    # filter.
    #
    # **Implementation:**
    #
    # *Please Note that the implementation serves as an illustration, and we
    # did not verify it’s correctness*
    
    import torch
    from torch.autograd import Function
    from torch.autograd import Variable
    
    from scipy.signal import convolve2d, correlate2d
    from torch.nn.modules.module import Module
    from torch.nn.parameter import Parameter
    
    
    class ScipyConv2dFunction(Function):
        @staticmethod
        def forward(ctx, input, filter):
            result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
            ctx.save_for_backward(input, filter)
            return input.new(result)
    
        @staticmethod
        def backward(ctx, grad_output):
            input, filter = ctx.saved_tensors
            grad_output = grad_output.data
            grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
            grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
    
            return Variable(grad_output.new(grad_input)), \
                Variable(grad_output.new(grad_filter))
    
    
    class ScipyConv2d(Module):
    
        def __init__(self, kh, kw):
            super(ScipyConv2d, self).__init__()
            self.filter = Parameter(torch.randn(kh, kw))
    
        def forward(self, input):
            return ScipyConv2dFunction.apply(input, self.filter)
    
    ###############################################################
    # **Example usage:**
    
    module = ScipyConv2d(3, 3)
    print(list(module.parameters()))
    input = Variable(torch.randn(10, 10), requires_grad=True)
    output = module(input)
    print(output)
    output.backward(torch.randn(8, 8))
    print(input.grad)
    

    在本例中,后向函数由 ScipyConv2dFunction 函数定义。

    ScipyConv2dFunction 被 forward 函数调用。

    我说的对吗?

    【讨论】:

    • 你还没有走到那一步。你说的是真的,但你还需要考虑 PyTorch 将已知的后向函数链接在一起。
    • 是的,这正是我感到困惑的地方。为什么简单地调用ScipyConv2dFunction.apply函数也会注册backward函数。我猜这是一个元类技巧,但不知道它到底发生在哪里。
    猜你喜欢
    • 1970-01-01
    • 2020-10-02
    • 2020-10-04
    • 2018-04-29
    • 2019-12-29
    • 2020-09-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多