rossiXYZ

[源码解析] PyTorch 流水线并行实现 (5)--计算依赖

0x00 摘要

前几篇文章我们介绍了 PyTorch 流水线并行的基本知识,自动平衡机制和切分数据等,本文我们结合论文内容来看看如何实现流水线依赖,核心就是如何建立这些小批次之间的跨设备依赖关系

流水线并行其他文章链接如下:

[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

[源码解析] 深度学习流水线并行 GPipe(3) ----重计算

[源码解析] 深度学习流水线并行之PipeDream(1)--- Profile阶段

[源码解析] 深度学习流水线并行 PipeDream(2)--- 计算分区

[源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型

[源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎

[源码解析] 深度学习流水线并行 PipeDream(5)--- 通信模块

[源码解析] 深度学习流水线并行 PipeDream(6)--- 1F1B策略

[源码解析] PyTorch 流水线并行实现 (1)--基础知识

[源码解析] PyTorch 流水线并行实现 (2)--如何划分模型

[源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统

[源码解析] PyTorch 流水线并行实现 (4)--前向计算

本文图片来自论文和github源码。

0x01 前文回顾

为了更好的理解本文,我们首先看看前文之中的关键部分。

  • 原始流水线状态如下:
    • 管道并行的策略是根据分区索引 j 分配任务,以便第 j 个分区完全位于第 j 个设备中。
    • 持有模型后期部分的设备必须等待,直到持有模型早期部分的设备计算结束。

img

  • 目标流水线状态如下:

img

  • 目前问题

    • 如果分成若干个微批次,则需要强制要求 \(F_{i,j}\) 必须在 \(F_{i+1,j}\) 之前完成,以及 \(B{i,j}\) 必须在执行\(B{i-1,j}\) 之前完成
    • 后向传播的计算图是在前向传播过程中动态构造的。PyTorch既不记录正向计算图,也不维护一个梯度磁带(gradient tape),PyTorch的自动微分(autograd)引擎仅对计算图进行反向传播。这意味着自动加载引擎可能不会完全按照与正向过程相反的执行顺序运行,除非由图的结构强制执行
  • 目前难点

    • 如何在每个设备中以正确的顺序发布那些绑定到设备的任务,以避免由于Python解释器未能提前请求而延迟在设备上(与CPU异步)执行任务。[这个前文已经介绍]
    • 如何建立这些小批次之间的跨设备依赖关系
  • 实现方案

    • 如何保证正确执行顺序?torchgpipe引入了确定性时钟周期(deterministic clock-cycle),它给出了任务的总体顺序[这个前文已经介绍]
    • 如何保证计算图中的动态显式依赖关系?针对clock_cycles产生的每一个运行计划:
      • 利用 fence 函数调用“fork”和“join”,以此在向后计算图中动态创建显式后向传播依赖关系。
      • 利用 compute(schedule, skip_trackers, in_queues, out_queues) 进行计算。

因为前文已经介绍了执行顺序方案,所以本文介绍如何计算依赖。

0x02 计算依赖

+-----------------------------------------------------------------------------------------+
|                                                                                         |
| Layer 1 +--->  Layer 2 +-----> Layer 3 +----->  Layer 4 +-----> Layer 5  +---> Layer 6  |
|                                                                                         |
+--------------------------+---------------------------+----------------------------------+
                                          +
                                          |
                                          |
                                          v
 +------------------------------------------------------------------------------------+
 | +--------------------+         +---------------------+      +--------------------+ |
 | |Partition 1         |         |Partition 2          |      |Partition 3         | |
 | |                    |         |                     |      |                    | |
 | |      Layer 1       |    +---------> Layer 4        |      |                    | |
 | |         +          |    |    |         +           |  +------->   Layer 6      | |
 | |         |          |    |    |         |           |  |   |                    | |
 | |         v          |    |    |         |           |  |   |                    | |
 | |      Layer 2       |    |    |         |           |  |   |                    | |
 | |         +          |    |    |         v           |  |   |                    | |
 | |         |          |    |    |      Layer 5 +---------+   |                    | |
 | |         v          |    |    |                     |      |                    | |
 | |      Layer 3  +---------+    |                     |      |                    | |
 | |                    |         |                     |      |                    | |
 | +---------+----------+         +---------+-----------+      +-----------+--------+ |
 |                                                                                    |
 +------------------------------------------------------------------------------------+

为什么需要计算依赖?

  • 因为模型已经被分层,模型的不同部分拆开放到不同设备上,数据也分成微批次,所以本来模型内部是线性依赖关系,现在需要变成流水线依赖关系。因此原始计算图不能满足需求,因此需要有针对性的补充。就像上图那样,6个层被分成了三个partitions,这三个partitons 之间的依赖如何构建
  • 之前的线性依赖关系其实是在模型定义时候就基本确定了,现在则需要每次运行时候建立一个动态依赖关系。

所以针对流水线并行,torchgpipe需要自己补充一个本机跨设备伪分布式依赖关系。torchgpipe 通过在前向计算图和后向计算图做各种调整来达到目的。计算图就意味着各种依赖逻辑,依赖逻辑的补足就是依靠本节介绍的 Fork 和 Join 两个函数完成的。

这里最初有一个疑问,就是Torchgpipe怎么在不使用 PyTorch RPC 和 p2p的情况下,构建出来一个异地反向计算图。后来发现,原来是我想多了,因为Torchgpipe没有考虑到这种情况,它针对都是在同一个主机之上的GPU,不涉及异地多机器计算。

Torchgpipe 本质上还是一个进程内运行多个线程进行计算,是 DP 的替代。比如源码中就有对比如下:

### ResNet-101 Accuracy Benchmark

Batch size | torchgpipe | nn.DataParallel | Goyal et al.
---------- | ---------: | --------------: | -----------:
256        | 21.99±0.13 |      22.02±0.11 |   22.08±0.06
1K         | 22.24±0.19 |      22.04±0.24 |          N/A
4K         | 22.13±0.09 |             N/A |          N/A

再比如代码中明确提到:

If you decide not to use checkpointing at all, :class:`nn.DataParallel
<torch.nn.DataParallel>` might be more efficient than GPipe.

0x03 反向传播依赖

我们首先看看反向传播依赖,这个是论文的重点。

2.1 解析

我们还是要回忆一下前面两个图例。

图1

img

图2

img

这里需要完成两种依赖:

  • 行间依赖,就是 batch 之间的依赖,就是设备内的依赖。从图上看,就是蓝色列内的 \(F_{1,1}\) 必须在 \(F_{2,1}\)之前完成,\(B_{2,1}\) 必须在\(B_{1,1}\) 之前完成。
  • 列间依赖,就是 partitions(设备) 之间的依赖。从图上看,就是蓝色 \(F_{1,1}\) 必须在黄色 \(F_{1,2}\)之前完成,即第一个设备必须在第二个设备之前完成,而且第一个设备的输出是第二个设备的输入。

假定我们依据确定性时钟周期(deterministic clock-cycle)算法来运行一个前向传播。即使前向传播是按照在第j个设备上应该执行的顺序来执行任务 \(F_{1,j},...,F_{m,j}\) ,得到的后向传播结果计算图看起来也更像图1而非图2,

从图1上看,PyTorch 的 autograd 引擎不知道 \(B_{i+1,j}\) 必须在 \(B_{i,j}\) 之前运行,因此会打乱后向传播的时间流。因此,虚拟依赖(图2的虚线箭头)必须在前向传播中被显式绘制出来。

我们再仔细分析一下图2。图2之中,每一行都表示一个 micro-batch 在训练中的运行流,这个流的前向是由clock算法确定的。后向关系是由前向传播中自动确定完成的

现在的问题是:一个 mini-batch 被分成了4个 micro-batch,分别在不同时钟周期进入训练。就是每一列。这一列由上到下的传播也是由clock算法确定,但是反向传播(由下自上)目前是不确定的。比如最后一列中,反向传播的顺序应是:\(B_{4,1},B_{3,1},B_{2,1},B_{1,1}\)。但是这个目前从前向传播的结果来看,无法确定这个顺序。

所以需要依靠本节介绍的 Fork 和 Join 两个函数完成这个依赖关系。图中斜线表示checkpoint之中需要先有一个重计算,然后才能由下往上走

因此,torchpipe定义两个基础函数,Fork 和 Join 来表达这种依赖关系:

  • Fork 是 auto grad 函数,其把一个张量 x 映射到 pair(x, \(\phi\)),这里 \(\phi\) 是一个空张量。

  • Join 是 auto grad 函数,其把 pair(x, \(\phi\)) 映射到一个张量 x ,这里 \(\phi\) 是一个空张量。

现在,\(F_{i+1,j}\) 对于 \(F_{i,j}\) 的依赖(其在后向传播计算图中被转换为 \(B_{i,j}\) 到 $B_{i+1,j} $ 的依赖关系)可以被如下表示

所以,图中这里实线都是前向传播时候构建的,虚线是由 fork & join 构建的。

原则上,表示虚拟依赖关系的张量可以是任意的。然而,torchgpipe选择使用空张量,以消除由张量引起的任何不必要的计算,例如PyTorch中的梯度累积。

具体如下图。就是使用 Fork 和 Join 的后向计算图。图中,不同颜色对应不同的设备。箭头依据后向传播图的方向来绘制,这些联系是在前向传播中被构建的。因此,\(F^{'}_{i,j}\) 对于 \(B_{i+1,j}\) 的虚拟依赖通过 Fork 和 Join 被构建出来,用虚线表示。

2.2 基础功能

2.2.1 Function

首先,我们要看看 torch.autograd.Function 的作用。

torch.autograd.Function类实际上是一个操作函数的基础父类,这样的操作函数必须具备两个基本的过程,即前向的运算过程和反向的求导过程,

如果某些操作无法通过 PyTorch 已有的层或者是已有的方法实现不了,就需要实现一个新的方法对 PyTorch 进行拓展。当不使用自动求导机制,需要自定义求导规则的时候,就应该拓展torch.autograd.Function类。 由于pytorch不再提供自动求导机制,就要用户自己定义实现前向传播和反向传播的计算过程,这就是 "Extending torch.autograd"。

我们接下来介绍Backward Dependency 的关键算法:Fork and Join。

2.2.2 Fork

Fork 是auto grad 函数,其把一个张量 x 映射到 pair(x, \(\phi\)),这里 \(\phi\) 是一个空张量。Fork 方法就是拓展了torch.autograd.Function

def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor."""
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)

    return input, phony


class Fork(torch.autograd.Function):
    @staticmethod
    def forward(ctx: 'Fork', input: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        phony = get_phony(input.device, requires_grad=False)
        return input.detach(), phony.detach()

    @staticmethod
    def backward(ctx: 'Fork', grad_input: Tensor, grad_grad: Tensor) -> Tensor:  # type: ignore
        return grad_input

2.2.3 Join

Join 是auto grad 函数,其把 pair(x, \(\phi\)) 映射到一个张量 x ,这里 \(\phi\) 是一个空张量。Join 方法也是拓展了torch.autograd.Function

def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merges two autograd lanes."""
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)

    return input


class Join(torch.autograd.Function):
    @staticmethod
    def forward(ctx: 'Join', input: Tensor, phony: Tensor) -> Tensor:  # type: ignore
        return input.detach()

    @staticmethod
    def backward(ctx: 'Join', grad_input: Tensor) -> Tuple[Tensor, None]:  # type: ignore
        return grad_input, None

2.2.4 Phony

Phony是没有空间的张量,因为它不需要任何梯度累积,所以可以在 autograd 图中构建任意的依赖。

def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
    """Gets a phony. Phony is tensor without space. It is useful to make
    arbitrary dependency in a autograd graph because it doesn't require any
    gradient accumulation.

    .. note::

        Phonies for each device are cached. If an autograd function gets a phony
        internally, the phony must be detached to be returned. Otherwise, the
        autograd engine will mutate the cached phony in-place::

            class Phonify(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input):
                    phony = get_phony(input.device, requires_grad=False)
                    return phony.detach()  # detach() is necessary.

    """
    key = (device, requires_grad)

    try:
        phony = _phonies[key]
    except KeyError:
        with use_stream(default_stream(device)):
            phony = torch.empty(0, device=device, requires_grad=requires_grad)

        _phonies[key] = phony

    return phony

2.2.5 detach

在代码中,经常可以见到 detach 的使用,这个从注释可以看出来,是为了解决 PyTorch 的一个bug。

    # A Python autograd function might fail with this error:
    #
    #   RuntimeError: Returning Variables sharing storage with other Variables
    #   that require grad is not supported in Python functions. Please submit a
    #   feature request if you hit this error.
    #
    # It doesn't look like an essential restriction. But it happens on the
    # current PyTorch version. To avoid it, we should detach the tensor before
    # returning by identity autograd functions, such as Wait, Fork, and Join.
    #

2.3 使用

在 Pipeline 之中我们可以看到具体的使用方法,fence 方法(省略部分代码)利用 depend 来构建后向传播的依赖关系,确保 batches[i-1] 在 batches[i] 之后完成。

    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
                depend(batches[i-1], batches[i]) # 在这里建立了后向传播依赖关系
                
            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j-1][i]
                copy(batches[i], prev_stream, next_stream)                

具体 depend 代码如下:

def depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)

我们结合示例代码把传入的参数赋值一下,重新把方法解释如下,这样大家就可以更好的理解。

def depend(batches[i-1]: Batch, batches[i]: Batch) -> None:
    batches[i-1][0], phony = fork(batches[i-1][0])
    batches[i][0] = join(batches[i][0], phony)

具体逻辑如下,通过 phony 完成了一个桥接,即在正向传播之中,batches[i] 依赖 batches[i-1] 的执行结果

      +----------------+          +--------------+
      |                |          |              |
      |  batches[i-1]  |          |  batches[i]  |
      |                |          |              |
      +----------+-----+          +-----+--------+
                 |                      |
                 |                      |
                 |                      |
                 v                      v
+--------------------------------------------------------+
| depend         |                      |                |
|                |                      |                |
|                |                      |                |
|                v                      |                |
|        +-----------------------+      |                |
|        | fork  |               |      |                |
|        |       |    get_phony  |      |                |
|        |       |        +      |      |                |
|        |       |        |      |      |                |
|        |       |        |      |      |                |
|        +-----------------------+      |                |
|                |        |             |                |
|                |        |             |                |
|                |        |             |                |
|                v        v             |                |
|    +-----------+--+  +--+-----+       |                |
|    |              |  |        |       |                |
|    | batches[i-1] |  | phony  |       |                |
|    |              |  |        |       |                |
|    +--------------+  +--+-----+       |                |
|                         |             |                |
|                         |             |                |
|                         v             v                |
|                      +--+------------------+           |
|                      |Join            |    |           |
|                      |                |    |           |
|                      |                |    |           |
|                      |                v    |           |
|                      +---------------------+           |
|                                       |                |
|                                       |                |
|                                       |                |
|                                       v                |
|                                 +-----+------+         |
|                                 |            |         |
|                                 | batches[i] |         |
|                                 |            |         |
|                                 +------------+         |
|                                                        |
+--------------------------------------------------------+

我们把多个 batches 联合起来看看,这样就能看出来一个依赖链条。

                  +----------------------------------------------------------+
                  | depend                                                   |
                  |                                                          |
                  | +------------+                                           |
 +-------------   | |fork        |     +-----------+                         |
 |            |   | |            |     |           |                         |
 |batches[i]  +----------------------> | batches[i]|                         |
 |            |   | |            |     |           |                         |
 +-------------   | |            |     +-----------+                         |
                  | |            |             +-------+                     |
                  | |            +-----------> | Join  |                     |
                  | |            |             |       |                     |
                  | +------------+             |       |                     |
 +-------------   |                            |       |    +--------------+ |
 |            |   |                            |       |    |              | |
 |batches[i+1]+-------------------------------------------->+ batches[i+1] | |
 |            |   |                            |       |    |              | |
 +---------+---   |                            |       |    +--------------+ |
           |      |                            +-------+                     |
           |      |                                                          |
           |      +----------------------------------------------------------+
           |      +----------------------------------------------------------+
           |      | depend                                                   |
           |      |                                                          |
           |      | +-------------+                                          |
           |      | |fork         |     +------------+                       |
           |      | |             |     |            |                       |
           +--------------------------> |batches[i+1]|                       |
                  | |             |     |            |                       |
                  | |             |     +------------+                       |
                  | |             |           +-------+                      |
                  | |             +---------> |Join   |                      |
                  | +-------------+           |       |                      |
+------------+    |                           |       |     +-------------+  |
|            |    |                           |       |     |             |  |
|batches[i+2]+--------------------------------------------> | batches[i+2]|  |
|            |    |                           |       |     |             |  |
+----------+-+    |                           |       |     +-------------+  |
           |      |                           +-------+                      |
           |      |                                                          |
           |      +----------------------------------------------------------+
           |
           |      +-----------------------------------------------------------+
           |      | depend                                                    |
           |      |                                                           |
           +----------------------------->    ......                          |
                  |                                                           |
                  |                                                           |
                  +-----------------------------------------------------------+

这样,上图就是前向计算图,于是在后向传播之中,batches[i] 就 必须在 batches[i-1] 之前完成了

我们再结合论文的图来看看。

本来示例代码中是:

depend(batches[i-1], batches[i])

为了和论文中的图对应,我们修改为:

depend(batches[i], batches[i+1])

depend 代码也变化为:

def depend(batches[i]: Batch, batches[i+1]: Batch) -> None:
    batches[i][0], phony = fork(batches[i][0])
    batches[i+1][0] = join(batches[i+1][0], phony)

对应下图,就是在后向传播计算图之中 batches[i+1] 通过一个join, 一个fork,排在了 batches[i] 前面,就是下面大箭头所示,具体细化一下:

  • 从这个图上,PyTorch 的 autograd 引擎不知道 \(B_{i+1,j}\) 必须在 \(B_{i,j}\) 之前运行,因此会打乱后向传播的时间流。因此,虚拟依赖(前面图的虚线箭头)必须在前向传播中被显式绘制出来。

  • 图上的实线箭头依据后向传播图的方向来绘制,这些联系是在前向传播中被构建的。就是说,对于 \({Batch}_i\) 来说,其反向传播顺序是固定的。就是上面一行内顺序是固定的,下面一行内顺序也是固定的

  • 但是,上下两行之间的顺序是不可知的,需要用虚线来保证,就是用 Join & Fork 来保证。

0x03 正向传播依赖

我们回头再来看正向依赖。因为正向传播的部分目的就是完成反向传播依赖,而目前反向传播只完成了行之间的依赖,列之间的依赖没有完成,我们现在补全

列之间的依赖就是设备之间的依赖,即前一个设备的输出是后一个设备的输入

3.1 分割模型

首先还是需要回顾下如何切分模型,从 split_module 可以看到,

GPipe 的 partitions 成员变量是 nn.ModuleList 类型。nn.ModuleList是一个容器,其储存不同 module,并自动将每个 module 的 parameters 添加到网络中。但是nn.ModuleList 并没有定义一个网络,而只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序,网络的执行顺序是根据 forward 函数来决定的。

def split_module(module: nn.Sequential,
                 balance: Iterable[int],
                 devices: List[torch.device],
                 ) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:

    balance = list(balance)

    j = 0
    partitions = []
    layers: NamedModules = OrderedDict()

    for name, layer in module.named_children(): # 遍历模型包含的层
        layers[name] = layer # 把新的层加入到数组中

        if len(layers) == balance[j]: # 如果数组大小等于balance[j],就是达到了device j应该包含的层数
            # Group buffered layers as a partition.
            partition = nn.Sequential(layers) # 把层数组组合成一个sequential module

            device = devices[j]
            partition.to(device) # 把层放置到相关设备之上

            partitions.append(partition) # 这个新module加入到分区数组中

            # Prepare for the next partition.
            layers.clear()
            j += 1 # 去下一个device看看

    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
    del devices[j:]

    return partitions, balance, devices

随之而来问题就是:partition内部可以用Sequential来进行一系列的前向操作,但是如何配置partitions 之间的执行顺序?

+-----------------------------------------------------------------------------------------+
|                                                                                         |
| Layer 1 +--->  Layer 2 +-----> Layer 3 +----->  Layer 4 +-----> Layer 5  +---> Layer 6  |
|                                                                                         |
+-----------------------------------------+-----------------------------------------------+
                                          |
                                          |
                                          |
                                          v
+-----------------------------------------------------------------------------------------+
| +--------------------+           +---------------------+         +--------------------+ |
| |Partition 1         |           |Partition 2          |         |Partition 3         | |
| |                    |   ???     |                     |         |                    | |
| |      Layer 1       |     +----------> Layer 4        |   ???   |                    | |
| |         +          |     |     |         +           |     +------->   Layer 6      | |
| |         |          |     |     |         |           |     |   |                    | |
| |         v          |     |     |         |           |     |   |                    | |
| |      Layer 2       |     |     |         |           |     |   |                    | |
| |         +          |     |     |         v           |     |   |                    | |
| |         |          |     |     |      Layer 5 +------------+   |                    | |
| |         v          |     |     |                     |         |                    | |
| |      Layer 3  +----------+     |                     |         |                    | |
| |                    |           |                     |         |                    | |
| +--------------------+           +---------------------+         +--------------------+ |
|                                                                                         |
+-----------------------------------------------------------------------------------------+

3.2 建立依赖

我们还是从论文中入手。假定我们有一个神经网络,其由一系列子网络构成。我们假定这些子网络是 \(f^1,...,f^n\),其参数分别是 \(\theta^1,...,\theta^n\),则整个网络是:

参数是 \(\theta = (\theta^1,...,\theta^n)\),为了清楚起见,我们称 \(f^j\) 表示 f 的第 j 个分区,并假设分区的参数是相互不相交的。

在训练网络时,基于梯度的方法(如随机梯度下降法)需要在给定小批量训练数据 x 和相应损失之后,计算网络的输出结果f(x)。以及损失相对于网络参数 \(\theta\) 的梯度g。这两个阶段分别称为向前传播和向后传播。

既然 f 由其 L 层 子模块 (\(f^L, f^{L-1},...f^1\)) 顺序组成,那么前向传播\(f(x)\) 可以通过如下方式计算:让 \(x^0=x\)(就是输入x),然后顺序应用每一个 partition,即 \(x^j = f^j (x^{j-1})\),这里 $ j = 1, ..., L$。就是 \(f(x)\) 可以表示为 :

\[f(x) = f^L(f^{L-1}(f^{L-2}(... f^1(x)))) \]

于是我们知道了,前向传播的顺序是由 \(f(x) = f^L(f^{L-1}(f^{L-2}(... f^1(x))))\) 来确定的

我们可以针对代码,进一步解析,看看如何实施partitions之间的顺序依赖。

    def run(self) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.

        """
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)

        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        with spawn_workers(devices) as (in_queues, out_queues):
            for schedule in clock_cycles(m, n): # 这里使用,给出了执行序列计划,后续按照这个来执行
                self.fence(schedule, skip_trackers)
                self.compute(schedule, skip_trackers, in_queues, out_queues)

解析的目标是 for schedule in clock_cycles(m, n) 这个 for 循环,其:

  • 针对clock_cycles产生的每一个运行计划:
    • 利用 fence(schedule, skip_trackers) 构建后向传播依赖关系。
    • 利用 compute(schedule, skip_trackers, in_queues, out_queues) 进行计算。

现在我们完成了两步:

  1. 确定性时钟周期算法给定了前向传播的执行顺序,我们只要按照 clock_cycles 方法提供的计划一一运行即可
  2. fence 方法通过调用 join 和 fork,我们做到了在后向传播之中,batches[i] 就 必须在 batches[i-1] 之前完成了,即 \(B_{i+1,j}\) 必须在 \(B_{i,j}\) 之前运行。

对于我们的图来说,第二步就是完成了下图的列依赖。

我们的问题是:怎么通过这个 for 循环,做到 \(B_{i,{j+1}}\) 必须在 \(B_{i,j}\) 之前运行?,即怎么安排反向传播逐次运行?就是怎么完成行内的依赖?

这就要通过 compute 的源码进行分析。重点说明的是:

  • batches[i] 这里是会变化的,比如 batches[0] 在经过 partitions[j] 的计算之后,会变成 batches[0][j]
  • 对于 compute 方法,关键就是在最底部的代码 batches[i] = batch。就是把 第 j 个device 对 第 i 个 batch 的计算结果 赋值到 batches[i],赋值之后,batches[i]就是 batches[i][j],这样,在下次计算时候,构建的就是 F[i, j+1], 下一次 fence 之中的 depend 操作,就是针对 batches[i, j+1]
  • 因此,在前向计算图上,通过这个赋值操作, batches[i, j+1] 就依赖 batches[i, j],所以反向计算时候,batches[i, j + 1] 就必须在 batches[i, j] 之前完成
    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        n = len(partitions)
        streams = [current_stream(d) for d in devices]
  
        for i, j in schedule: # 针对 schedule 之中的每一对 i,j
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)

            # Determine whether checkpointing or not.

            if checkpoint:
							# 忽略
            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition) # 前向计算,计算以 partition为单位计算,partition内部的层是顺序计算,由 Sequential保证。

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task) # 让 worker计算

        for i, j in schedule:
            ok, payload = out_queues[j].get() # 获取 worker 的前向计算结果,就是 第 j 个device 对 第 i 个 batch 的计算结果

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)

            # 第 j 个device 对 第 i 个 batch 的计算 就是 F[i,j]

            batches[i] = batch # 这里是关键,就是把 第 j 个device 对 第 i 个 batch 的计算结果 赋值到 batches[i],batches[i]就是 batches[i][j],在下次计算时候,构建的就是 F[i,j+1], 下一次 fence 之中的 depend 操作,就是针对 batches[i,j+1]

关于这个赋值操作,其对应的grad_fn 是 PermuteBackward,比如:

a = torch.tensor([2., 3.], requires_grad=True)
c = a
c.backward(gradient=external_grad)
print(c)

具体是:

c = {Tensor: 2} tensor([2., 3.], requires_grad=True)
  T = {Tensor: 2} tensor([2., 3.], grad_fn=<PermuteBackward>)

现在,我们把下图进行升级。

                 +-------------------------------------------------------------------+
                 | depend                                                            |
                 |                                                                   |
                 | +---------------+                                                 |
                 | |fork           |                                                 |
+-------------   | |               |     +-----------+                               |
|            |   | |               |     |           |                               |
|batches[i]  +-------------------------> | batches[i]|                               |
|            |   | |               |     |           |                               |
+-------------   | |               |     +-----------+                               |
                 | |               |                                                 |
                 | |               |                                                 |
                 | |               |     +--------+    +-------+                     |
                 | |  get_phony +------> |        +--->+ Join  |                     |
                 | |               |     | phony  |    |       |                     |
                 | +---------------+     |        |    |       |                     |
                 |                       +--------+    |       |                     |
                 |                                     |       |                     |
+-------------   |                                     |       |    +--------------+ |
|            |   |                                     |       |    |              | |
|batches[i+1]+----------------------------------------------------->+ batches[i+1] | |
|            |   |                                     |       |    |              | |
+-------------   |                                     |       |    +--------------+ |
                 |                                     +-------+                     |
                 |                                                                   |
                 +-------------------------------------------------------------------+

我们进行横向拓展,得到如下,即一个batch 被分成两个小批次: batches[i],batches[i+1] ,它们在两个设备 partitions[j],partitions[j + 1] 之上流水线,这样行和列都有反向传播的依赖。

                                 F[i,j]                                                                            F[i,j+1]

                    +------------------------------------------------+                            +-----------------------------------------------+
                    | partitions[j]                                  |                            |  partitions[j+1]                              |
                    |                                                |                            |                                               |
                    | +--------------------+   +------------------+  |                            | +-------------------+   +------------------+  |
                    | |fence               |   | compute          |  |                            | | fence             |   | compute          |  |
                    | |                    |   |                  |  |                            | |                   |   |                  |  |
+--------------+    | |  +--------------+  |   |  +------------+  |  |     +-----------------+    | |   +-------------+ |   |  +------------+  |  |       +-----------------+
|              |    | |  | depend       |  |   |  |forward     |  |  |     |                 |    | |   | depend      | |   |  |forward     |  |  |       |                 |
|  batches[i]  +---------------------------------------------------------> | batches[i][j]   +----------------------------------------------------------> | batches[i][j+1] |
|              |    | |  |              |  |   |  |            |  |  |     |                 |    | |   |             | |   |  |            |  |  |       |                 |
+--------------+    | |  |              |  |   |  |            |  |  |     +-----------------+    | |   |             | |   |  |            |  |  |       +-----------------+
                    | |  |              |  |   |  +------------+  |  |                            | |   |             | |   |  +------------+  |  |
                    | |  |              |  |   |                  |  |                            | |   |             | |   |                  |  |
+--------------+    | |  |              |  |   +------------------+  |     +-----------------+    | |   |             | |   +------------------+  |       +-------------------+
|              |    | |  |              |  |                         |     |                 |    | |   |             | |                         |       |                   |
|  batches[i+1]+---------------------------------------------------------> | batches[i+1][j] +----------------------------------------------------------> | batches[i+1][j+1] |
|              |    | |  |              |  |                         |     |                 |    | |   |             | |                         |       |                   |
+--------------+    | |  +--------------+  |                         |     +-----------------+    | |   +-------------+ |                         |       +-------------------+
                    | |                    |                         |                            | |                   |                         |
                    | +--------------------+                         |                            | +-------------------+                         |
                    +------------------------------------------------+                            +-----------------------------------------------+

手机如下:

0x04 总结

下图 $ m = 4, n = 3$。即,模型被分成3个子网络,小批次被分割成 4个微批次。F 和 B 的下标是 (m, n)。

img

如上图,这里需要完成两种依赖:

  • 行间依赖,就是 batch 之间的依赖,就是设备内的依赖。从图上看是虚线,就是 \(F_{1,1}\) 必须在 \(F_{2,1}\)之前完成,\(B_{2,1}\) 必须在\(B_{1,1}\) 之前完成。
  • 列间依赖,就是 partitions(设备) 之间的依赖。从图上看是实线,就是 \(F_{1,1}\) 必须在 \(F_{1,2}\)之前完成,即第一个设备必须在第二个设备之前完成,而且第一个设备的输出是第二个设备的输入。

如上图,我们需要完成行,列两方面的依赖。

  • 行间依赖是用 Join & Fork 来保证,利用空张量完成了依赖关系的设定,确保 batches[i-1] 在 batches[i] 之后完成。
  • 列间依赖是通过 batches[i] = batch 完成,利用 PermuteBackward 来完成了设备之间的依赖。

至此,我们完成了执行顺序和依赖关系的设定,下一篇我们介绍如何并行处理。

0xFF 参考

Markdown公式用法大全

markdown中公式编辑教程

https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior

CUDA学习:基础知识小结

CUDA随笔之Stream的使用

NVIDIA解决方案架构师深度解析大规模参数语言模型Megatron-BERT

Accelerating Wide & Deep Recommender Inference on GPUs

HugeCTR: High-Performance Click-Through Rate Estimation Training

https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548

https://github.com/NVIDIA/apex/

https://github.com/justheuristic/prefetch_generator

https://pytorch.org/tutorials/intermediate/model_parallel_turotial.html

https://pytorch.org/docs/stable/autograd.html

https://pytorch.org/docs/notes/cuda.html

https://zhuanlan.zhihu.com/p/61765561

https://pytorch.apachen.org/docs/1.7/64.html

https://zhidx.com/p/217999.html

相关文章: