rossiXYZ

[翻译] TensorFlow 分布式之论文篇 "Implementation of Control Flow in TensorFlow"

读论文有一种原则是:本领域最经典的论文,近5年最热的论文,近1年最新的论文。按照这个原则,本文主要介绍一篇Tensorflow 经典论文 Implementation of Control Flow in TensorFlow

本系列相关文章如下:

[翻译] TensorFlow 分布式之论文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

1. 概览

本文介绍了 TensorFlow 中控制流操作符的当前设计和实现。这是一份基于原始设计的描述性文档,具体细节请参见实际源代码。本文内容是:

  • 介绍五个 TensorFlow 的核心操作符,它们是专门为处理控制流而添加的。
  • 展示高层控制流结构如何基于这五个基础操作符被编译进数据流图。
  • 解释这些数据流图如何由 TensorFlow runtime 执行,包括在一组混合设备(如CPU、GPU和TPU)上的分布式执行方式。
  • 描述如何对控制流结构进行自动求导。

本文图均来自原始论文。

2. 控制流原语

TensorFlow 中控制流的基本设计原则是:引入一个包含少量操作的简单原子操作集,在这些操作符之上来表达TensorFlow 应用的复杂控制流。我们希望这些基元是灵活且富有表现力的,可以作为高级领域特定语言(DSL)的一个良好的编译目标。它们应该与 TensorFlow 的数据流模型相兼容,并且可以方便实施并行,分布式执行以及自动微分。如下图所示,原子操作集之中有五个控制流原语运算符,其中 Switch 和 Merge 组合起来可以实现条件控制。所有五个基元一起组合则可以实现 while 循环。

图 1 基元

在 TensorFlow 中,每个 op 都在一个执行帧(execution frame)中执行,控制流原语负责创建和管理这些执行帧。对于每个 while 循环,TensorFlow 运行时会设置一个执行帧,并在执行帧内运行 while 循环的所有操作。执行帧可以嵌套。嵌套的 while 循环在嵌套的执行帧中运行。只要执行帧之间没有数据依赖关系,则来自不同执行帧的操作可以并行运行。

Switch:Switch 运算符会根据输入控制张量 p 的布尔值,将输入张量 d 转发到两个输入中的一个。只有两个输入都准备好之后,Switch 操作才会执行。

Merge:Merge 运算符将其可用的输入之一转发到其输出。只要它的任何一个输入可用,merge 运算符就会执行。如果有多个可用的输入,则无法确定它的输出。

Enter(name):Enter 操作符将其输入转发到由给定名称唯一标识的执行帧。这个 Enter 操作用于将一个执行帧中的张量传递给一个子执行帧。对于同一个子执行帧可以有多个 Enter 操作,每个操作都会使子执行帧中的张量可用(异步)。当输入可用时,Enter 操作将执行。一个新的执行帧在执行该帧第一个 Enter 操作时候被实例化。

Exit:Exit 操作符将一个张量从一个执行帧返回给它的父执行帧。一个执行帧可以有多个 Exit 操作返回到父执行帧,每个操作都异步地将张量传回给父帧。当一个 Exit 的输入可用时,该 Exit 操作就被启用。

NextIteration: 一个 NextIteration 操作符将其输入转发到当前执行帧的下一个迭代。TensorFlow 运行时会跟踪维护执行帧中的迭代信息。一个执行帧中执行的任何操作都有一个唯一的迭代 ID,这使得我们能够唯一地识别迭代计算中同一操作的不同调用(比如 hile 操作之中,某一个 op 可能会多次执行)。请注意,一个执行帧中可以有多个 NextIteration操作。当执行帧的第 N 次迭代的第一个 NextIteration 操作开始执行时,TensorFlow 运行时就开始进行第 N+1 次迭代。随着更多的张量通过执行 NextIteration 操作进入下一个迭代,新迭代中更多操作就开始执行。当一个 NextIteration 的输入可用时,它就被启用。

3. 控制流结构的编译

因为增加了这 5 个控制原语,例如 cond 和 while_loop 这样的高级编程结构就可以被编译成数据流图,从而可以被 TensorFlow 执行。我们接下来看看条件表达式和 while 循环如何在 Tensorflow 内部实现。

3.1 条件表达式

下面是构建条件表达式 cond(pred, fn1, fn2) 数据流图的高级伪代码。为了简单起见,我们忽略了实际实现中的许多重细节。读者可以在 control_flow_ops.py 中找到相关的实现细节。

# Build the graph for the true branch 
context_t = CondContext(pred, branch=1) 
res_t = context_t.Call(fn1)

# Build the graph for the false branch 
context_f = CondContext(pred, branch=0) 
res_f = context_f.Call(fn2)

# Add the Merge nodes for the outputs
merges = [Merge([f, t]) for (f, t) in zip(res_f, res_t)] 
return merges

对于条件表达式的每一个分支,我们都会为条件语境创建一个新的控制流上下文,并在上下文中调用其计算图构造函数(fn1或fn2)。条件上下文允许我们捕获任何外部张量(不是在上下文中创建的),并插入一个适当的Switch 操作来确保其进入一个分支。这保证了分支中的任何操作只有在该分支被选择时才会执行。由于 TensorFlow 模型的异步执行特点,这些外部张量可能在非常不同的时间变得可用,所以我们为每个外部张量使用一个 Switch op 来最大化并行度。

因为每个分支返回一个张量列表(ref_t或res_f),所以我们需要添加一个 Merge 操作来对该结果列表每个输出的真值/假值进行合并。同样,输出可能在不同的时间被计算,所以我们对每个输出使用一个 Merge 操作,这使我们能够尽快启用下游的计算。让我们来看一个简单的例子:

图 2 条件表达式

tf.cond(x<y, lambda: tf.add(x,z), lambda: tf.square(y))

在生成的数据流图中,Switch 操作被用来控制张量 x、y和z 的流动。在 true/false 分支中,只使用 Switch 操作的真/假输出。由于 add 的输入来自 Switch 操作的 true 分支输出,所以 add 操作只在 x<y 为真时执行。同样地,Square 操作只在 x<y 为假时执行。Add 或 Square 的结果由最后的 Merge 操作发出。如果条件表达式有多个输出,就会有多个 Merge 操作,每个输出都有一个 Merge 操作结果。

有很多种使用 Switch 和 Merge 对 cond 进行编码的方法,我们选择目前的编码方式主要是因为它使 cond 自动求导变得更简单。

3.2 while 循环

以下是构建 while 循环数据流图的高层伪代码:

while_context = WhileContext()
while_context.Enter()

# Add the Enter nodes for each loop variable.
enter_vars = [Enter(x, frame_name) for x in loop_vars]

# Add the Merge nodes. Note that input[1] will be updated later.
merge_vars = [Merge([x,x]) for x in enter_vars]

# Build the loop pred subgraph.
pred_result = pred(*merge_vars)

# Add the Switch nodes.
switch_vars = [Switch(x, pred_result) for x in merge_vars]

# Build the loop body subgraph.
body_result = body(*[x[1] for x in switch_vars])

# Add the NextIteration nodes.
next_vars = [NextIteration(x) for x in body_result]

# Form the cycles for the loop.
for m,v in zip(merge_vars, next_vars):
    m.op._update_input(1,v)

# Add the Exit nodes.
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()
return exit_vars

整个 while 循环图是在 while 循环的控制流上下文之中创建的。这里的基本思路很简单。

从循环变量开始,我们为每个循环变量添加一个 Enter 操作,其后面跟着一个 Merge 操作。然后我们使用其结果(merge_vars)来建立 pred 子图,pred 子图将计算循环的终止条件。

在加入 Switch 操作后,我们使用 Switch 的 true 分支输出来构建 while 循环主体的子图。循环主体的结果需要进入下一个迭代,所以我们添加 NextIteration 操作,并将其输出连接到 Merge 操作的第二个输入。这就形成了循环,这使我们在执行图的时候可以多次重复运行同一个操作。

Switch 操作的假值输出是整个 while 循环的输出,所以我们在假值输出后面插入了 Exit 操作,并返回 Exit 操作的输出。与 cond 类似,while 循环的上下文被用来跟踪 pred 和 body lambdas 中使用的外部张量。这些外部张量被视为循环常量,我们为每个这样的外部张量自动插入一个 Enter 操作,使其可以在 while 循环上下文中访问。嵌套循环需要添加嵌套的 Enter 操作。

同样,让我们看看一个简单程序的生成图例子。

图 3 while 循环

tf.while_loop(lambda i:i<10, lambda i: tf.add(i,1),[0])

在这个例子中,我们只有一个循环变量。如果有多个循环变量,我们需要添加多个 Enter、Merge、Switch、NextIteration 和 Exit 操作。这样就可以并行执行跨循环和循环内跨迭代的操作。我们省略了在 while 循环中如何处理常量的方法。如果你想了解其细节,请看具体代码。

cond 和 while_loop 的这种转换方法可以支持条件表达式和循环的任意嵌套。例如,一个循环体可以调用另一个 while_loop,它将被递归地翻译成一个嵌套的子图。该翻译确保每个循环被静态地分配一个唯一的框架名称。

4. 实现

TensorFlow 运行时负责数据流图的执行。让我们先快速浏览一下。为了在多个设备上运行,TensorFlow 会自动将操作分配到设备集上。TensorFlow 基于设备的具体放置来自动将数据流图分割成一组子图,每个设备一个子图。当一条边被分区切分时,我们会自动插入一对发送和接收节点,用于在设备间传输张量。一对 send 和 recv 使用一个唯一的 key 进行通信,recv 会主动从 send 中提取数据(这里是特色)。例如,下图是将一个图划分到两个设备上的结果,TensorFlow 对分区没有施加任何限制。只要某个节点的计算可以在一个设备上完成,它就可以被分配到该设备上。

图 4 划分后的计算图

当一个子图被分配到某一个设备之后,这个子图就被该设备的本地执行器管理。执行器从源节点开始,依次执行准备好的节点。除了合并节点外,一个节点在其所有输入都可用时,就成为就绪节点。注意,子图中的所有 recv 节点都被认为是源节点。

如果没有控制流,图的执行就非常直接。每个节点都仅仅被执行一次,当所有节点都被执行过之后,执行就结束了。控制流引入了相当的复杂性。一个节点现在可以被执行任何次数(包括 0 在内)。执行器需要能够管理同一节点内多个实例的执行(可能是并发的),并确定图执行何时会完成。

为了跟踪执行过程中产生的张量,我们使用一个元组 d = (value, is_dead, tag) 来标示执行器中的张量,其中 value 是实际的张量,is_dead 是一个布尔值(用来表示该张量是否在一个未执行的条件分支上),而 tag 是唯一标识该张量(以及产生该张量的节点的执行实例)的字符串。直观地说,tag 定义了一个执行环境,在一个执行环境中,一个节点最多执行一次。标签是发送/转发之间通信 key 的一部分,以区分同一发送/转发节点之间的多个调用。执行者遵循以下执行规则(注意:一个节点的所有输入必须有相同的标签。)

Switch(p,d) = (r1,r2)
r1 = (value(d), p || is_dead(d),tag(d))
r2 = (value(d), !p || is_dead(d),tag(d))

Merge(d1,d2) = r
r = if is_dead(d1) then d2 else d1

Enter(d, frame_name) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag(d)/frame_name/0

Exit(d) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag1 where tag(d)=tag1/frame_name/n

NextIteration(d) = d1
value(d1) = value(d)
is_dead(d1) = is_dead(d)
tag(d1) = tag1/frame_name/(n+1) where tag(d) = tag1/frame_name/n

Op(d1,...,dm) = (r1,...,rn)
value(ri) = Op.Compute(value(d1),...,value(dm)) if !is_dead(ri)
is_dead(ri) = any(is_dead(d1),...,is_dead(dm)), for all i
tag(ri) = tag(d1), for all i

最后一条规则是针对所有非控制流节点的。请注意,只有当所有的输入都有效时,才会进行实际的计算。如果有一个无效输入,我们将跳过计算并向下游传播一个 dead 信号。这种 dead 信号的传播可以被用来支持控制流的分布式执行。

5. 分布式条件表达式

对于分布式执行来说,一个条件表达式可能被切分到多个设备上,如下图所示:

图 5 切分表达式

由于任何 recv 节点都是一个随时无条件启动的源节点,所以,即使设备 B 上的 recv 节点是在条件表达式的未选择分支之内,它也可能会执行。为了使未选择分支上的 recv 的执行合理化,我们在设备间把 is_dead 标志通过 send 节点发送到 recv 节点。传播可以在任何数量的设备上继续进行。这个简单的传播机制可以处理嵌套条件的分布式执行,也有助于 while 循环的分布式执行。

6. 分布式的 while 循环

对于分布式执行,一个 while 循环,特别是循环主体,可以被切分到多个设备上。如果我们简单地应用切分方案:只是为跨设备的边插入 send/recv 节点,那么设备上的本地执行器将缺少足够的信息来正确运行 while 循环。

图 6 切分控制流简单方案

让我们用一个简单的例子来说明这些问题。在上面的例子中,Op 在循环体中,被分配给设备B。一个简单切分会将 Switch 到 Op 的边拆分,插入一对 send/recv 节点,由这对节点完成跨设备数据传输。然而,这是不可行的,因为设备 B 不知道 recv 和 Op 节点是一个 while 循环的一部分,这样设备 B 在一个迭代后就会终止执行。解决方案是重写数据流图,在每个分区添加一个控制循环状态机(如下图设备 B 的右下角所示)。控制循环 Enter 节点是一个标量 0。

图 7 切分控制流改进方案

这些控制循环提供了足够的信息,这样通过发送/接收节点相互通信,就可以使设备上的执行器能够像以前一样独立运行。请注意,图中的虚线是控制边。让我们先看一下基本用例,即 while 循环只运行 0 次迭代。

  • 在设备 A 上,节点 Enter、Merge、P 和 Switc 依次被执行。因为 P 是 false,所以连接到 Switch 的 Send 会向设备 B 传播一个死信号,这样 Exit 也会运行,从而使循环之外依赖这个 Exit 的节点能够同时执行。连接到P 的 Send将 向设备 B 发送布尔张量 False,这样 Recv 也可以被执行,其会等待来自设备 B 的值。
  • 在设备 B 上,Enter 触发了循环,接下来依次执行节点 Enter 和 Merge。Merge 的执行使两个 Recv 得以执行。Switch 的 Recv 会收到 False,所以 Next 会得到一个死张量,于是停止了循环。Op 的 Recv 会得到一个死张量,所以 Op 的 Send 会把一个死张量送回设备 A,此时,设备 B 没有未完成的操作,所以执行结束。
  • 在设备 A 上,Recv for Next 得到了一个死张量。Next 运行,由于它停止了死循环的传播,设备 A 没有未完成的操作,所以执行结束。

我们接下来看看 while 循环运行一个或多个迭代。

  • 在设备 A 上,由于 P 在第一次迭代时为真,一个实数张量被发送到设备 B。同时 Recv 被执行,等待来自设备B 返回的值。

  • 在设备 B 上,控制循环状态机运行并启用 Recv。Recv 为 Op 从设备 A 得到一个实数张量;Op 被执行,Send 将一个实数张量送回设备 A。执行 Next 和 Merge,进一步启用下一个迭代的 Recv。

  • 在设备 A 上,Recv 得到一个实数张量。然后执行 Next、Merge 和 P。根据 P 的值,将执行基本情况或新的迭代。

请注意,在执行过程中存在大量的并行性。例如,设备 B 一旦收到 P 的值,就可以开始下一个迭代或退出。一个参与设备可以有多个迭代在并行运行,而且两个参与设备可以同时在同一个循环的不同迭代中工作。

分布式执行 while 循环的开销是每个参与设备在每次迭代时都需要从产生 P 的设备那里接收一个布尔张量,考虑到执行中的并行性,开销在很大程度上应该是与计算重叠,因此可以忽略。

下面显示了当一个 while 循环被划分到多个设备上时,数据流图是什么样子的。一个控制循环被添加到每个分区中,并控制 while 循环中的 Recvs。重写后的图在语义上与原始图是等价的。

图 8 重写的计算图

对于嵌套的 while 循环,我们按如下方式把控制循环堆叠起来。注意,如果一个设备只有外层循环的节点,我们将不会在其上添加任何与内层循环有关的控制循环结构。

图 9 嵌套

7. 自动微分

TensorFlow 支持自动求导。例如,用户可以定义一个带有损失函数的神经网络,而 TensorFlow 将自动推导并构建反向传播数据流图。本节解释了 TensorFlow 如何在有 cond 和 while_loop 的情况下自动构建反向传播图。我们假设读者对自动反向传播的工作方式有一定的了解。(参见链接 [1],这是一篇关于反向传播的优秀文章)。

反向传播算法以反向顺序遍历前向图中的操作,并通过调用操作注册的梯度函数逐步构建梯度图。一个操作的梯度函数定义了计算该操作梯度的子图。梯度函数可能会使用到运算的输入/输出值,因此在前向计算中产生的一些张量将被保留一段时间,直到它在反向传播之中被使用。例如,下面显示了一个前向运算和它的梯度图。G(Op) 是Op 的梯度子图。x 和 y 的值将被保存在内存中,直到 G(Op) 被执行。

图 10 反向传播

一旦构建了整个数据流图,TensorFlow 运行时就会自动对图进行分割,并将执行分布在多个设备上。因此,TensorFlow 中的梯度计算也将被分配到多个设备上运行。

直观地讲,在 cond 和 while_loop 的上下文之中,控制流算子的反向传播以如下方式进行反向传播。Exit 的梯度是 Enter;Switch 的梯度是 Merge(对于cond)或者 NextIteration 之后接着一个 Merge(对于while_loop);Merge 的梯度是 Switch;NextIteration 的梯度是 Identity;Enter 的梯度是 Exit。TensorFlow 支持嵌套条件和while循环的反向传播。

7.1 条件表达式的反向传播

直观地说,cond(p, fn1, fn2) 的梯度为 cond(p, g_fn1, g_fn2),其中 g_fn1 和 g_fn2 分别为 fn1 和 fn2 的梯度。下面显示了当 cond 没有嵌套在 while 循环中,cond 的基本反向传播操作。我们假设 Op 位于 cond 的 true 分支上。如果 cond 被嵌套在 while 循环,那么它需要做更多的工作来记住前向循环每次迭代的 p 值。我们将在后面看while 循环的反向传播时讨论这个问题。

图 10 条件表达式的反向传播

前向传播之中的 Merge 在后向传播之中被转化为 Switch,它使用与前向 Switch 相同的谓词 p。梯度 g 被反推到Switch 的两个分支。

前向 Switch 被转化为 Merge。如果前向 Switch 中只有一个分支在前向传播之中被用到了,我们会添加一个零输入到反向传播的 Merge,如下图所示,以确保在反向传播之中总有一个活跃的梯度流经 Merge。这个零输入被一个 Switch 来控制,所以它只在 p 为 false 时才会被发送到 Merge。

图 12 Switch 转换

7.2 While 循环的反向传播

直观地说,while_loop(pred, body) 的梯度也是以 while loop 的形式存在。

def pred(i, _): return i < N

while_loop(pred, g_body, [0] + g_vars) 

其中 N 是前向传播 while 循环运行的迭代次数,g_body 是前向循环体的梯度,g_vars 是循环变量的初始值。我们将在后面看到,g_vars 包括前向 while 循环变量的初始梯度。下面是一个 while 循环的前向传播和反向传播图。

图 13 While 循环的反向传播

请注意,Backprop 循环由 N 控制,即前向循环运行的迭代次数。这意味着我们假设 pred 是不可训练的。G(Body) 是 Body 的梯度。Body 可能再次包含 while 循环,所以这个结构可能会递归地出现,以处理嵌套的 while 循环。

到目前为止,这个描述是相当过度简化了。实际上,在图的构造过程中,N 并不是静态已知的。更重要的是,G(Body) 可能会使用前向传播过程中产生的值,我们希望保留这些值,以避免在反推过程中重新计算它们。解决方案是重写前向 while 循环的图,对于反向传播之中需要的值,增加计算和/或保存的逻辑。

为了计算 N,我们在前向 while 循环中加入以下子图(计算 N 的逻辑)。因此,N 将由前向循环动态计算,并作为后向循环的计数循环变量的初始值。

图 14 计算逻辑

为了在反向传播循环中重用前向传播计算出来的数值,我们在构建反向传播 while 循环的过程中,自动检测反向传播中需要的前向值。对于每个这样的前向值 x,我们自动引入一个堆栈,并在前向循环中添加节点,以便在每次迭代时将其值保存到堆栈中。反向传播循环以相反的顺序使用堆栈中的值。堆栈位于前向和反向传播循环之外,由两个循环共享(所以下图有两个 Enter)。

图 15 循环共享

实际的计算图构造实际上比这更微妙和复杂。下面是一些问题。

  • 为了保证正确性,我们需要确保堆栈的 push 和 pop 是按其各自循环的迭代来排序的。我们还需要确保前向传播的堆栈必须在后向传播的堆栈之前完成排序。这些顺序是通过控制边来完成的。
  • 为了提高性能,我们使堆栈 push 和 pop 操作成为异步的,因此它们可以与实际计算并行运行。例如,op(甚至是未来的迭代)可以与 push 并行运行。
  • 如果 op 在一个嵌套在 while 循环内的 cond 里面,那么入栈和出栈操作必须由 cond 的谓词进行适当的保护。
  • 如果某个值在反向传播之中被缩减操作(如 Shape、Rank或Size)处理,我们将缩减操作移到前向循环中以减少内存的使用。

如前所述,Enter 的梯度是 Exit。对于循环变量,这就是它的全部作用。对于循环常量,我们还添加了一个子图来累积它们的梯度,如下图所示。

图 16 累计梯度

假设 x 是前向传播中的一个循环常数。在 Backprop 中,每次迭代都会为 x 产生一个 partial gradient。因此,我们在反向传播过程中添加小的累积子图,然后将所有这些部分梯度加在一起。最终结果 \(g_x\) 是所有偏导数的总和。注意,积累是 eagerly 地进行的,以并行迭代的次数为界。这与 static unrolling 不同,在 static unrolling 中,AddN 需要所有的部分梯度在同一时间生效。

这种结构对嵌套条件和循环都有效。对于嵌套在 while 循环中的条件式,我们引入一个堆栈来保存每次前向迭代的谓词值,并在反向 prop 中使用堆栈中的值(以相反的顺序)。对于嵌套的循环,当我们遇到嵌套在循环体中的内部 while 循环时,会递归地调用这个结构。

一个重要的优化是内存交换(memory swapping)。正如我们所看到的,对于每个在 backprop 中需要的前向值 v,我们将其在所有迭代中的值 \(v_1,...,v_N\)保存在一个堆栈中,所以我们会在 backprop 中重使它们。这对于在内存有限的设备(如GPU)上进行训练是一个限制。我们使用内存交换来异步地将存储在堆栈中的值从 GPU 移动到 CPU,并在 Backprop 中需要时将它们移回 GPU 内存中。

0xFF 参考

Implementation of Control Flow in TensorFlow

tensorflow源码解析之distributed_runtime

TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems,

TensorFlow: A system for large-scale machine learning

Implementation of Control Flow in TensorFlow

Dynamic Control Flow in Large-Scale Machine Learning

Control Flow in Tensorflow TF中的控制流解析

tensorflow control flow 2---the implementation of control flow

https://blog.csdn.net/zhenhailiu/article/details/80466920

链接

[1] http://colah.github.io/posts/2015-08-Backprop/

相关文章: