rossiXYZ

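[Source Code Analysis] TensorFlow Distributed Variables

In TensorFlow, a distributed variable is a variable created on multiple devices. Mirrored variables and SyncOnRead variables are two examples. This article analyzes distributed variables, guided by the following questions:

  • How does variable creation end up calling into the Strategy?
  • How is a MirroredVariable generated?
  • How are tensors distributed to the individual devices?
  • How is a single, unified view presented to the outside?
  • How are the component variables kept consistent?

As before, I recommend two excellent resources:

[TensorFlow Internals] (https://github.com/horance-liu/tensorflow-internals): although it does not analyze the latest code, I recommend it to anyone interested in TF's internal mechanisms. You will certainly learn a lot.
https://home.cnblogs.com/u/deep-learning-stacks/ 西门宇少: not only TensorFlow; his WeChat official account also covers many other fields and cutting-edge industry work.

Other articles in this series:

[Translation] TensorFlow Distributed, Papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed, Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

[Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic

[Source Code Analysis] TensorFlow Distributed Environment (8) --- Communication Mechanism

[Translation] Distributed Training with TensorFlow

[Source Code Analysis] TensorFlow Distributed DistributedStrategy: Basics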

1. MirroredVariable

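tf.distribute.MirroredStrategy supports synchronous distributed training on multiple GPUs on one machine. The strategy creates one replica per GPU device. Each variable in the model is mirrored across all replicas. Together, these copies form a single conceptual variable called a MirroredVariable. The copies stay in sync because the same updates are applied to each of them.

Figure 1 MirroredVariable

A concrete usage example: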

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Variable created inside scope:
with strategy.scope():
  mirrored_variable = tf.Variable(1.)

# Variable created outside scope:
regular_variable = tf.Variable(1.)

The printed output is:

>>> mirrored_variable
  MirroredVariable:{
    0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
  }

>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>  

Alternatively, see the example in tensorflow/python/module/module_test.py.

def test_supports_distributed_variables(self):
  mirrored = distributed_values.MirroredVariable(
      None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
  tpu = tpu_values.TPUMirroredVariable(
      strategy=None, values=[variables.Variable(42.)], aggregation=None)
  aggregating = ps_values.AggregatingVariable(
      strategy=None, v=variables.Variable(1.), aggregation=None)

  m = module.Module()
  m.a = mirrored
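The idea of several component variables forming one conceptual variable can be modeled without TensorFlow. The following is a minimal pure-Python sketch. `ToyMirroredVariable` is a hypothetical class, not part of TF, and it only captures one invariant: every copy receives the same update.

```python
class ToyMirroredVariable:
    """Hypothetical model: one logical variable backed by per-device copies."""

    def __init__(self, devices, initial_value):
        # One component per device, all starting from the same value.
        self.components = {d: initial_value for d in devices}

    def assign_add(self, delta):
        # Apply the identical update to every copy so they stay in sync.
        for d in self.components:
            self.components[d] += delta
        return self.read()

    def read(self):
        values = set(self.components.values())
        # Invariant of a mirrored variable: all copies agree.
        assert len(values) == 1, "copies diverged"
        return values.pop()


v = ToyMirroredVariable(["GPU:0", "GPU:1"], 1.0)
v.assign_add(2.0)
print(v.components)  # {'GPU:0': 3.0, 'GPU:1': 3.0}
```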

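1.1 Definition

The MirroredVariable docstring states its purpose: it holds a map from replica to variables whose values are kept in sync. MirroredVariable adds no new member variables. It only implements some member methods.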

class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
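    if values_util.is_saving_non_distributed():  # saving in non-distributed mode
      # Operate directly on the primary component
      return self._primary.scatter_update(*args, **kwargs)
    # Otherwise fall back to the distributed implementation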
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

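Take scatter_update as an example. When saving in non-distributed mode, it operates directly on _primary. Otherwise it calls the base-class method. During updates, _update_replica calls _on_write_update_replica to keep the replicas in sync. _on_write_update_replica then performs the update through the current replica context. It is defined in tensorflow/python/distribute/values.py.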

def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context."""
  if var.aggregation == vs.VariableAggregation.NONE:
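    return update_fn(var._get_on_device_or_primary(), value, **kwargs)

  if not ds_context.get_strategy().extended._use_merge_call():
    # Update directly in replica context without merge_call
    # (a dtype check for MEAN aggregation is omitted here).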
    aggregated_value = apply_aggregation_replica_context(
        value, var.aggregation, var)
    values_util.mark_as_unsaveable()

    return ds_context.get_replica_context()._update(  
        var,
        update_fn,
        args=(aggregated_value,),
        kwargs=kwargs,
        group=True)

  else:

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
      return var._update_cross_replica(update_fn, v, **kwargs)  

    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)
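The branching in _on_write_update_replica can be summarized with a small pure-Python model. This is a simplified sketch built on several assumptions:

- `Aggregation` stands in for tf.VariableAggregation.
- Plain floats stand in for per-replica variables.
- update_fn is assumed to be an addition, as with assign_add.
- The merge_call/_update machinery is collapsed into a direct sum.

The sketch only shows the effect on the copies. With NONE, each replica updates its own copy. Otherwise, the per-replica values are aggregated first, and the same result is applied to every copy.

```python
from enum import Enum


class Aggregation(Enum):
    # Hypothetical stand-in for tf.VariableAggregation.
    NONE = 0
    SUM = 1
    MEAN = 2


def on_write_update(components, replica_values, aggregation):
    """Hypothetical model of an ON_WRITE update issued from replica context.

    components: current value of each mirrored copy (one per replica).
    replica_values: the delta each replica passes to update_fn.
    Returns the new list of component values.
    """
    if aggregation is Aggregation.NONE:
        # update_fn runs on the local copy only; if replicas pass different
        # values, the copies drift apart.
        return [c + v for c, v in zip(components, replica_values)]
    # Otherwise aggregate across replicas first (SUM or MEAN here), then apply
    # the single aggregated value to every copy so they stay in sync.
    total = sum(replica_values)
    if aggregation is Aggregation.MEAN:
        total = total / len(replica_values)
    return [c + total for c in components]


print(on_write_update([1.0, 1.0], [1.0, 3.0], Aggregation.SUM))  # [5.0, 5.0]
```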

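These member methods alone do not give a clear picture of MirroredVariable, so we start from its class hierarchy.

1.2 Related Classes

1.2.1 Class Hierarchy

The MirroredVariable class hierarchy is shown below. We analyze each class in turn and then summarize at the end.

Figure 2 MirroredVariable class hierarchy

1.2.2 DistributedValues

We begin with DistributedValues.

Figure 3 DistributedValues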

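Distributed values are represented by the base class tf.distribute.DistributedValues. This concept is suited to representing values on multiple devices. It contains a map from replica ID to value.

A tf.distribute.DistributedValues holds one value per replica. Depending on the subclass, the values may be synced on update, synced on demand, or never synced. A tf.distribute.DistributedValues can be:

  • reduced to obtain a single value across replicas,
  • passed as input into tf.distribute.Strategy.run, or
  • inspected per replica using tf.distribute.Strategy.experimental_local_results.

DistributedValues is a base class and should never be instantiated directly. Instances of its subclasses are created within a distribution strategy instead: when creating variables, when iterating over a tf.distribute.DistributedDataset, or through tf.distribute.Strategy.run.

Two representative kinds of tf.distribute.DistributedValues are "PerReplica" and "Mirrored" values.

  • "PerReplica" values live on worker devices, with a different value for each replica. They are produced by iterating over the distributed datasets returned by tf.distribute.Strategy.experimental_distribute_dataset and tf.distribute.Strategy.distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.run.

  • "Mirrored" values are like "PerReplica" values, except that the value is the same on every replica. A "Mirrored" value can therefore be read safely in cross-replica context by using the value on any replica.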

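Definition

DistributedValues has two important members, _values and _primary. The values passed to the constructor are stored in the _values tuple. _primary is a property that returns the first element of that tuple.

The derived classes rely on several DistributedValues member methods, so we analyze those methods here.

  • _get_on_device_or_primary: inside a replica, returns that replica's value. Outside a replica, returns the value on the current device if there is one, and otherwise _primary.
  • _get_cross_replica: returns the cross-replica value. Its implementation is left to the derived classes.
  • _get: if there is no current replica ID (cross-replica context), calls _get_cross_replica to return the cross-replica value. Otherwise, returns the current replica's local value.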

The conceptual diagram:

Figure 4 DistributedValues

The DistributedValues code:

@tf_export("distribute.DistributedValues", v1=[])
class DistributedValues(object):
  """Base class for representing distributed values.

  A subclass instance of `tf.distribute.DistributedValues` is created when
  creating variables within a distribution strategy, iterating a
  `tf.distribute.DistributedDataset` or through `tf.distribute.Strategy.run`.
  This base class should never be instantiated directly.
  `tf.distribute.DistributedValues` contains a value per replica. Depending on
  the subclass, the values could either be synced on update, synced on demand,
  or never synced.

  `tf.distribute.DistributedValues` can be reduced to obtain single value across
  replicas, as input into `tf.distribute.Strategy.run` or the per-replica values
  inspected using `tf.distribute.Strategy.experimental_local_results`.
  """

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      return self._get_cross_replica()  # cross-replica context: return the cross-replica value
    else:
      return self._values[replica_id]  # replica context: return this replica's value

  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    # Get the current replica ID
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:  # no replica ID: look at the devices on this host
      # Try to find a value on the current device.
      # current_device is the canonical name (a string) of the current device
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:  # iterate over the components
        if device_util.canonicalize(value.device) == current_device:
          return value  # found a component on the current device
      return self._primary  # otherwise fall back to _primary
    else:
      # Return the value belonging to this replica
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    return tuple(v.device for v in self._values)
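The lookup order in _get and _get_on_device_or_primary can be modeled in plain Python. In this hypothetical sketch, the replica ID and the current device are passed in explicitly. TF obtains them from the replica context and the device scope instead. The class name and the device strings are illustrative only.

```python
class ToyDistributedValues:
    """Hypothetical pure-Python model of the lookup logic in DistributedValues."""

    def __init__(self, values, devices):
        self._values = tuple(values)
        self._device_of = tuple(devices)  # device of each component

    @property
    def _primary(self):
        # Like DistributedValues._primary: simply the first component.
        return self._values[0]

    def _get_cross_replica(self):
        # Left to subclasses, as in the base class.
        raise NotImplementedError("subclasses decide how to read across replicas")

    def _get(self, replica_id=None):
        # replica_id None models cross-replica context.
        if replica_id is None:
            return self._get_cross_replica()
        return self._values[replica_id]

    def _get_on_device_or_primary(self, replica_id=None, current_device=None):
        if replica_id is not None:
            # Inside a replica: return that replica's own component.
            return self._values[replica_id]
        # No replica ID: look for a component on the current device.
        for value, device in zip(self._values, self._device_of):
            if device == current_device:
                return value
        # Fall back to the primary component.
        return self._primary
```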

The code above relies heavily on get_current_replica_id_as_int. This function is defined in tensorflow/python/distribute/values_util.py and returns the current replica ID.

def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = ds_context.get_replica_context()
  if replica_context:
    replica_id = replica_context._replica_id
    if not isinstance(replica_id, int):
      replica_id = tensor_util.constant_value(replica_id)
  else:
    replica_id = distribute_lib.get_update_replica_id()
  return replica_id

Usage

The following usage examples come from the source. All of them use MirroredStrategy to obtain DistributedValues.

import tensorflow as tf

# 1. Created from a `tf.distribute.DistributedDataset`:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)

# 2. Returned by `run`:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group
distributed_values = strategy.run(run)

# 3. As input into `run`:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
@tf.function
def run(input):
  return input + 1.0
updated_value = strategy.run(run, args=(distributed_values,))

# 4. Reduce value:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                distributed_values,
                                axis=0)

# 5. Inspect local replica values:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)  # needed before inspecting it
per_replica_values = strategy.experimental_local_results(distributed_values)
print(per_replica_values)

# Output:
#  (<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
#   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)

1.2.3 DistributedDelegate

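Next, we look at DistributedDelegate.

Figure 5 DistributedDelegate

DistributedDelegate adds computation on top of DistributedValues. Its operator methods call _get_as_operand. _get_as_operand calls the base class DistributedValues' _get method to obtain the value for the current context, and the operation is then performed on that value. Attributes not found on the object itself are likewise forwarded to that value through __getattr__.

Figure 6 How computation works

DistributedDelegate is defined as follows (part of the code is omitted):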

class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' and the attrs starts with '_self' are used
    # for restoring the saved_model proto, and '_attribute_sentinel' is used for
    # Layer tracking. At the point these attrs are queried, the variable has not
    # been initialized. Thus it should not query those of the underlying
    # components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()

  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  # Most of the code is omitted
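The forwarding done by DistributedDelegate can be reproduced with a hypothetical class, `ToyDelegate`. This includes the `_values` guard that the comment ties to copy.copy. The sketch is simplified and is not the TF class. In particular, the replica ID is stored on the object instead of coming from a replica context.

```python
import copy


class ToyDelegate:
    """Hypothetical model of DistributedDelegate's forwarding behaviour."""

    def __init__(self, values, replica_id=0):
        self._values = tuple(values)
        self._replica_id = replica_id  # stands in for the replica context

    def _get(self):
        return self._values[self._replica_id]

    def __getattr__(self, name):
        # Only called when normal lookup fails. copy.copy builds an empty
        # instance without calling __init__, so "_values" is missing there;
        # raising AttributeError stops an infinite __getattr__ -> _get loop.
        if name == "_values":
            raise AttributeError(name)
        return getattr(self._get(), name)

    def __add__(self, o):
        return self._get() + o

    def __radd__(self, o):
        return o + self._get()


d = ToyDelegate([10, 20], replica_id=1)
print(d + 1, 1 + d)      # 21 21
print(copy.copy(d) + 0)  # 20
```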

1.2.4 PerReplica

PerReplica holds a map from replica to unsynchronized values.

class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    return PerReplicaSpec(
        *(type_spec.type_spec_from_value(v) for v in self._values))

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

1.2.5 Mirrored

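Next we come to Mirrored.

Figure 7 Mirrored

Mirrored represents variables created on multiple devices that are kept in sync by applying the same updates to every replica. Mirrored variables are created with tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...). They are usually used only for synchronous training.

Recall that DistributedValues holds a map from replica to values and leaves _get_cross_replica unimplemented. Mirrored values are meant to be usable directly in cross-replica mode, so Mirrored implements _get_cross_replica. The implementation calls the base class DistributedValues' _get_on_device_or_primary method (see the corresponding section). That method returns the value for the current replica or device, or otherwise the _primary value. Since all copies are identical, any one of them is a valid cross-replica value.

The conceptual diagram:

Figure 8 How Mirrored computes

Mirrored is defined as follows: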

# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    return self._get_on_device_or_primary()  # calls the base class DistributedValues method

  def _as_graph_element(self):
    obj = self._get()  # calls the base class DistributedValues method
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj
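The difference between PerReplica and Mirrored comes down to whether a cross-replica read is defined. A hypothetical pure-Python sketch follows. The class names are illustrative only, and `replica_id=None` models being in cross-replica context.

```python
class ToyValues:
    """Hypothetical base: holds one value per replica."""

    def __init__(self, values):
        self._values = tuple(values)

    @property
    def _primary(self):
        return self._values[0]

    def _get_cross_replica(self):
        raise NotImplementedError("no single cross-replica value")

    def read(self, replica_id=None):
        # replica_id None models being in cross-replica context.
        if replica_id is None:
            return self._get_cross_replica()
        return self._values[replica_id]


class ToyPerReplica(ToyValues):
    """Values differ per replica, so a cross-replica read stays undefined."""


class ToyMirrored(ToyValues):
    def _get_cross_replica(self):
        # All copies are identical, so the primary is a safe answer.
        return self._primary
```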

1.2.6 Policy

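Next, we look at variable policies.

Figure 9 Variable policies

VariablePolicy

VariablePolicy is the base class for variable policies. It defines how a distributed variable is synchronized and aggregated. When a variable is created within a tf.distribute scope, tf.distribute creates an appropriate policy object from the synchronization and aggregation arguments set on the tf.Variable. It then assigns that object to the distributed variable, and all variable operations are delegated to it.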

class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.

  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
  during variable creation within `tf.distribute` scope, `tf.distribute` creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.
  """

  def __init__(self, aggregation):
    self._aggregation = aggregation

  def value(self):
    raise NotImplementedError(
        "VariablePolicy.value should be overriden by sub-classes.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "VariablePolicy._is_mirrored should be overriden by sub-classes.")

  def _as_graph_element(self, _):
    raise NotImplementedError(
        "VariablePolicy._as_graph_element should be overriden by sub-classes.")

  def _get_cross_replica(self, var):
    raise NotImplementedError(
        "VariablePolicy._get_cross_replica should be overriden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    raise NotImplementedError(
        "VariablePolicy._update_replica should be overriden by sub-classes.")
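The docstrings of the two concrete policies (below) map synchronization modes to policy classes. That mapping can be sketched as follows. `Synchronization`, `choose_policy` and the toy policy classes are hypothetical. TF performs the real selection internally when a var_policy is used.

```python
from enum import Enum


class Synchronization(Enum):
    # Hypothetical stand-in for tf.VariableSynchronization.
    AUTO = 0
    ON_WRITE = 1
    ON_READ = 2


class ToyOnWritePolicy:
    def _is_mirrored(self):
        return True


class ToyOnReadPolicy:
    def _is_mirrored(self):
        return False


def choose_policy(synchronization):
    """Hypothetical helper: ON_READ gets an on-read policy, while ON_WRITE and
    AUTO get an on-write policy, as the policy docstrings describe."""
    if synchronization is Synchronization.ON_READ:
        return ToyOnReadPolicy()
    return ToyOnWritePolicy()


class ToyDistributedVariable:
    def __init__(self, synchronization):
        self._policy = choose_policy(synchronization)

    def is_mirrored(self):
        # Behaviour is delegated to the policy object, as with VariablePolicy.
        return self._policy._is_mirrored()
```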

OnReadPolicy

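OnReadPolicy is the policy for on-read synchronization. Its _get_cross_replica method, for example, calls var.distribute_strategy.reduce to aggregate the per-replica values when they are read (unless the aggregation is ONLY_FIRST_REPLICA).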

class OnReadPolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in `tf.distribute`
  scope.
  """

  def _is_mirrored(self):
    return False

  def value(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return var._get_replica(0).value()  
        return var._get_cross_replica()  
      else:
        return var._get_on_device_or_primary().value()  

  def _as_graph_element(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())  
    return var._get()._as_graph_element()  

  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # read from the first replica
    if self._aggregation == vs.VariableAggregation.SUM:
      values_util.mark_as_unsaveable()  # mark as unsaveable
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      # Let distribute_strategy perform the reduction
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)
    
  # Most of the code is omitted
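The reduction in OnReadPolicy._get_cross_replica can be modeled for each aggregation mode. This hypothetical sketch uses strings in place of tf.VariableAggregation and Python numbers in place of tensors. It also ignores the mark_as_unsaveable bookkeeping.

```python
def on_read_cross_replica_value(per_replica, aggregation):
    """Hypothetical model of OnReadPolicy._get_cross_replica.

    per_replica: list of the local values held by each replica.
    aggregation: "ONLY_FIRST_REPLICA", "SUM" or "MEAN".
    """
    if aggregation == "ONLY_FIRST_REPLICA":
        # Read the first replica's copy directly, no reduction.
        return per_replica[0]
    if aggregation == "SUM":
        return sum(per_replica)
    if aggregation == "MEAN":
        return sum(per_replica) / len(per_replica)
    raise ValueError("unsupported aggregation: %s" % aggregation)


# e.g. a per-replica example counter, summed only when read
print(on_read_cross_replica_value([32, 32], "SUM"))  # 64
```

A typical use of ON_READ with SUM is a per-replica counter such as a metric accumulator. Each replica increments its own copy cheaply, and the copies are summed only when the variable is read in cross-replica context.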

OnWritePolicy

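The OnWritePolicy class implements the on-write policy. Most of its operations go through var._get_on_device_or_primary(). For example, _get_cross_replica returns an identity of var._get_on_device_or_primary(). It also calls various basic operations in values_util.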

class OnWritePolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when the following `synchronization` and `aggregation`
  parameters are specified when creating a `tf.Variable` in `tf.distribute`
  scope and `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE`
  or `tf.VariableSynchronization.AUTO`.
  """

  def _is_mirrored(self):
    return True

  def value(self, var):
    return var._get_on_device_or_primary().value()  

  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()  

  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())  

  # Delegates to update_fn or _on_write_update_replica
  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    # Delegates the work to values_util
    return values_util.on_write_assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  # Discussed later
  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    return values_util.get_on_write_restore_ops(var, tensor)

  # Most of the code is omitted
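OnWritePolicy._get_cross_replica wraps the component in array_ops.identity. Callers therefore get a value rather than the variable itself. The same effect can be imitated in plain Python by returning a copy. `ToyOnWriteVariable` is hypothetical, and copying a list only stands in for the identity op.

```python
class ToyOnWriteVariable:
    """Hypothetical mirrored variable whose components are lists."""

    def __init__(self, components):
        self.components = [list(c) for c in components]

    def get_cross_replica(self):
        # Like array_ops.identity(primary): hand back a copy of the value,
        # not the underlying variable, so it cannot be modified by mistake.
        return list(self.components[0])


v = ToyOnWriteVariable([[1, 2], [1, 2]])
r = v.get_cross_replica()
r[0] = 99
print(v.components[0])  # [1, 2]
```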

values_util

Both policies above use on_write_assign_add, which is defined in tensorflow/python/distribute/values_util.py.

def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
  return var._update(  
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)

OnWritePolicy also uses the scatter_update defined in values_util, which again calls back into var._update.

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update( 
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)
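on_write_assign_add and scatter_update follow the same pattern. Each wraps the method call in a lambda and passes it to var._update, which decides which component(s) to apply it to. Below is a hypothetical pure-Python sketch of that pattern. It assumes the cross-replica path, in which every component is updated. `Component` and `ToyDistVar` are made-up names.

```python
class Component:
    """Hypothetical per-device variable supporting assign_add."""

    def __init__(self, value):
        self.value = value

    def assign_add(self, delta, read_value=True):
        self.value += delta
        return self.value if read_value else None


class ToyDistVar:
    """Hypothetical distributed variable holding one Component per device."""

    def __init__(self, values):
        self.components = [Component(v) for v in values]

    def _update(self, update_fn, value, **kwargs):
        # Models only the cross-replica path: apply the same update_fn to
        # every component, roughly what strategy.extended.update does.
        return [update_fn(c, value, **kwargs) for c in self.components]


def on_write_assign_add(var, value, read_value=True):
    # Same shape as values_util.on_write_assign_add: wrap the method call in
    # a lambda so _update can run it on whichever component(s) it chooses.
    assign_add_fn = lambda v, *a, **kw: v.assign_add(*a, **kw)
    return var._update(update_fn=assign_add_fn, value=value,
                       read_value=read_value)


var = ToyDistVar([1.0, 1.0])
print(on_write_assign_add(var, 2.0))  # [3.0, 3.0]
```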

1.2.7 DistributedVariable

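Following the class relationships, we finally arrive at DistributedVariable, which is where most of MirroredVariable's functionality actually lives.

Figure 10 DistributedVariable

DistributedVariable holds a map from replica to variables. When a var_policy is passed to the constructor, it is stored in self._policy, and concrete updates are delegated to it. For ON_WRITE (mirrored) semantics, that policy is OnWritePolicy.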

class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
    if ops.executing_eagerly_outside_functions() and getattr(
        strategy, "_enable_packed_variable_in_eager_mode", False):
      name = "%s/packed/" % self._common_name
      self._packed_var = packed.PackedDistributedVariable(values, name=name)
    else:
      self._packed_var = None

    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the components variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None
    # Set a VariablePolicy which decides how we replicate/aggregate the given
    # variable.
    self._policy = var_policy

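How a given operation is handled depends on the situation, but everything ultimately comes down to strategy or strategy.extended.

Reading

A read calls _get_cross_replica, which delegates to the policy. An on-read policy then calls distribute_strategy to perform the reduction.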

def _get_cross_replica(self):
  if values_util.is_saving_non_distributed(): 
    return self._primary  # saving in non-distributed mode: return the primary directly
  if self._policy:
    # Return the cross-replica value
    return self._policy._get_cross_replica(self)  

  raise NotImplementedError(
      "DistributedVariable._get_cross_replica requires a valid "
      "VariablePolicy. Please set the policy via the `var_policy` argument "
      "in the constructor, or override this method in sub-classes which "
      "support cross-replica accesses.")

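As shown below:

Figure 11 DistributedVariable reading

scatter_update

Likewise, scatter_update delegates the update to _policy when a policy is set. Otherwise, it calls values_util.scatter_update directly.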

def scatter_update(self, sparse_delta, use_locking=False, name=None):
  if values_util.is_saving_non_distributed():
    return self._primary.scatter_update(sparse_delta, use_locking, name)
  if self._policy:
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)
  return values_util.scatter_update(
      self, sparse_delta, use_locking=use_locking, name=name)

As discussed under OnWritePolicy, scatter_update eventually calls back into DistributedVariable's own _update method.

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
  return var._update(  
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

var._update has several execution paths. We analyze only some of them.

def _update(self, update_fn, value, **kwargs):
  """Applies updates depending on the context.

  The method calls `_update_replica` in replica context,
  `_update_cross_replica` in cross replica context, and `update_fn` in update
  context.

  If `read_value` is True, the method returns the updated Variable. If
  `read_value` is False, the method returns the update `tf.Operation`.

  Args:
    update_fn: A callable to pass to `strategy.extended.update` to update the
      variable. It should have the same signature as `Variable.assign()`.
    value: value to be passed to `update_fn`.
    **kwargs: keyword arguments to `update_fn`.

  Returns:
    Updated variable or `tf.Operation`.

  """
  if values_util.is_saving_non_distributed():
    return update_fn(self._primary, value, **kwargs)  # non-distributed saving

  with ds_context.enter_or_assert_strategy(self.distribute_strategy):
    if ds_context.in_cross_replica_context():
      update_replica_id = distribute_lib.get_update_replica_id()
      if update_replica_id is not None:
        replica_value = self._get_replica(update_replica_id)
        return update_fn(replica_value, value, **kwargs)
      return self._update_cross_replica(update_fn, value, **kwargs)  # cross-replica update
    else:
      values_util.assert_replica_context(self.distribute_strategy)
      return self._update_replica(update_fn, value, **kwargs)

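In cross-replica context without an update replica ID, _update calls _update_cross_replica, which hands the update to strategy.extended.update.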

def _update_cross_replica(self, update_fn, value, **kwargs):
  """Applies updates across replicas.

  Args:
    update_fn: A callable to pass to `strategy.extended.update` to update the
      variable. It should has the same signature as `Variable.assign()`.
    value: value to be passed to `update_fn`.
    **kwargs: remaining arguments to `update_fn`.

  Returns:
    Updated variable or `tf.Operation`.
  """
  values_util.mark_as_unsaveable()
  return self.distribute_strategy.extended.update(
      self, update_fn, args=(value,), kwargs=kwargs, group=True)
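The branches of _update and _update_cross_replica amount to a dispatch on the calling context. In this hypothetical sketch, the context flags are explicit keyword arguments. The replica-context branch models only the aggregation=NONE case, in which the local copy is updated. With other aggregations, the real code goes through _update_replica and aggregates first.

```python
import operator


def toy_update(components, update_fn, value, *, saving_non_distributed=False,
               cross_replica=True, update_replica_id=None, replica_id=None):
    """Hypothetical model of the branches in DistributedVariable._update.

    Mutates `components` in place and returns the name of the branch taken.
    """
    if saving_non_distributed:
        # Only the primary component is touched.
        components[0] = update_fn(components[0], value)
        return "primary"
    if cross_replica:
        if update_replica_id is not None:
            # Inside strategy.extended.update: update just that replica's copy.
            i = update_replica_id
            components[i] = update_fn(components[i], value)
            return "update_replica"
        # Plain cross-replica context: update every component.
        for i, c in enumerate(components):
            components[i] = update_fn(c, value)
        return "cross_replica"
    # Replica context (aggregation=NONE only): update the local copy.
    components[replica_id] = update_fn(components[replica_id], value)
    return "replica"


c = [1, 1]
print(toy_update(c, operator.add, 2), c)  # cross_replica [3, 3]
```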

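This is illustrated below:

Figure 12 DistributedVariable update

1.2.8 Saving

Next, we look at how a MirroredVariable is saved. In _gather_saveables_for_checkpoint, the _saveable_factory closure uses _MirroredSaveable to implement saving.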

class MirroredVariable(DistributedVariable, Mirrored):

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

_MirroredSaveable defines how a MirroredVariable is saved and restored.

class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    # Build the saveable tensor and spec via get_on_write_saveable
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)

The get_on_write_saveable code:

def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""
  # We use a callable so that we don't have to evaluate this expression
  # in the case where we are trying to restore instead of save.
  def tensor():
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of `None` indicates that the variable is
      # uninitialized.
      return None
    strategy = var.distribute_strategy
    return strategy.extended.read_var(var)  # read the tensor to save

  spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)

  return tensor, [spec]

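read_var in tensorflow/python/distribute/mirrored_strategy.py reads the value across replicas. For a sync-on-read variable, it returns the aggregated cross-replica value. Otherwise, it returns an identity of the variable's current value.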

def read_var(self, replica_local_var):
  """Read the aggregate value of a replica-local variable."""
  if distribute_utils.is_sync_on_read(replica_local_var):
    return replica_local_var._get_cross_replica()
  return array_ops.identity(replica_local_var._get())
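Saving and restoring a MirroredVariable can be modeled with a hypothetical class. A lazily evaluated callable reads one value at save time, mirroring the `tensor` callable in get_on_write_saveable. Restore then writes that single value back into every component. The dictionary key "value" is invented for the sketch.

```python
class ToyMirroredSaveable:
    """Hypothetical model of _MirroredSaveable plus get_on_write_saveable."""

    def __init__(self, components):
        self.components = components
        # Like the `tensor` callable: the value is read only when a save
        # actually happens, not when the saveable is constructed.
        self.tensor = lambda: self.components[0]

    def save(self):
        # The checkpoint stores a single value for the whole mirrored variable.
        return {"value": self.tensor()}

    def restore(self, saved):
        # "Restore the same value into all variables."
        value = saved["value"]
        for i in range(len(self.components)):
            self.components[i] = value


src = ToyMirroredSaveable([4.0, 4.0])
dst = ToyMirroredSaveable([0.0, 0.0, 0.0])
dst.restore(src.save())
print(dst.components)  # [4.0, 4.0, 4.0]
```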

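1.2.9 Summary

The analysis above yields the following annotated MirroredVariable inheritance hierarchy. Much of its functionality is ultimately implemented by tf.distribute.Strategy.

Figure 13 Annotated MirroredVariable inheritance hierarchy

1.3 Building Variables

A variable created under MirroredStrategy is a MirroredVariable. If the strategy's constructor arguments specify no devices, the strategy uses all available GPUs. If no GPU is found, it uses the available CPU. Note that TensorFlow treats all CPUs on a machine as a single device and uses threads internally for parallelism. Next, we look at how a MirroredVariable is built.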
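The default device choice described above can be sketched as a pure function. `default_devices` is hypothetical and takes a list of device strings. In real code, the available devices would come from an API such as tf.config.list_logical_devices. MirroredStrategy's actual selection logic is not reproduced here.

```python
def default_devices(available):
    """Hypothetical helper: prefer all GPUs, otherwise fall back to the CPU.

    available: list of device strings such as "GPU:0" or "CPU:0".
    """
    gpus = [d for d in available if d.startswith("GPU:")]
    if gpus:
        return gpus
    # All CPUs on a host count as one device, so a single CPU entry is used.
    return ["CPU:0"]


print(default_devices(["CPU:0", "GPU:0", "GPU:1"]))  # ['GPU:0', 'GPU:1']
```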

1.3.1 StrategyBase

First, tensorflow/python/distribute/distribute_lib.py contains the following code. It shows that scope is actually implemented by _extended.

def scope(self):
  """Returns a context manager selecting this Strategy as current.

  Inside a `with strategy.scope():` code block, this thread
  will use a variable creator set by `strategy`, and will
  enter its "cross-replica context".

  Returns:
    A context manager.
  """
  return self._extended._scope(self)  

1.3.2 StrategyExtendedV2

That brings us to StrategyExtendedV2. Its _scope defines creator_with_resource_vars, which supplies the mechanism for creating variables; internally, creator_with_resource_vars calls the subclass's _create_variable to build the variable.

def _scope(self, strategy):
  """Implementation of tf.distribute.Strategy.scope()."""

  def creator_with_resource_vars(next_creator, **kwargs):
    """Variable creator to use in  _CurrentDistributionContext ."""
    _require_strategy_scope_extended(self)
    kwargs["use_resource"] = True
    kwargs["distribute_strategy"] = strategy

    # Unwrap  initial_value  if it is a  CheckpointInitialValue  to avoid
    # dereferencing a  Tensor  that is without a  name . We still need to
    # propagate the metadata it's holding.
    if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
    elif isinstance(kwargs["initial_value"],
                    trackable.CheckpointInitialValueCallable):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
    elif (isinstance(kwargs["initial_value"], functools.partial) and
          isinstance(kwargs["initial_value"].func,
                     trackable.CheckpointInitialValueCallable)):
      # Some libraries (e.g, Keras) create partial function out of initializer
      # to bind shape/dtype, for example:
      #  initial_val = functools.partial(initializer, shape, dtype=dtype)
      # Therefore to get the restore_uid we need to examine the "func" of
      # the partial function.
      checkpoint_restore_uid = kwargs[
          "initial_value"].func.checkpoint_position.restore_uid
    else:
      checkpoint_restore_uid = None

    created = self._create_variable(next_creator, **kwargs)

    if checkpoint_restore_uid is not None:
      # Let the checkpointing infrastructure know that the variable was
      # already restored so it doesn't waste memory loading the value again.
      # In this case of CheckpointInitialValueCallable this may already be
      # done by the final variable creator, but it doesn't hurt to do it
      # again.
      created._maybe_initialize_trackable()
      created._update_uid = checkpoint_restore_uid
    return created

  def distributed_getter(getter, *args, **kwargs):
    return getter(*args, **kwargs)

  # creator_with_resource_vars is installed here
  return _CurrentDistributionContext(
      strategy,
      variable_scope.variable_creator_scope(creator_with_resource_vars), # configures how variables are created
      variable_scope.variable_scope(
          variable_scope.get_variable_scope(),
          custom_getter=distributed_getter), self._default_device)

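The logic is as follows: entering the scope goes through a series of steps and returns a _CurrentDistributionContext, which performs its own series of steps when entered. Let's keep going.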

Figure 14 How variables are created
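To see why creating a variable inside the scope ends up in _create_variable, here is a pure-Python toy model of a variable-creator stack. Everything in it (the module-level stack, make_variable, dicts standing in for variables, the two pretend devices) is hypothetical; it only mimics the next_creator chaining pattern that variable_scope.variable_creator_scope and creator_with_resource_vars rely on, not TensorFlow's actual implementation.

```python
import contextlib
from typing import Any, Callable, Dict, List

# Hypothetical toy creator stack; NOT TensorFlow's implementation.
_creator_stack: List[Callable[..., Any]] = []


def _default_creator(**kwargs: Any) -> Dict[str, Any]:
    # Innermost fallback: "creates" a plain variable, represented as a dict.
    return {"kind": "plain", **kwargs}


@contextlib.contextmanager
def variable_creator_scope(creator: Callable[..., Any]):
    _creator_stack.append(creator)
    try:
        yield
    finally:
        _creator_stack.pop()


def make_variable(**kwargs: Any) -> Any:
    # The most recently entered creator runs first and receives a
    # `next_creator` that continues down the stack.
    def build(i: int) -> Callable[..., Any]:
        if i < 0:
            return _default_creator
        return lambda **kw: _creator_stack[i](build(i - 1), **kw)
    return build(len(_creator_stack) - 1)(**kwargs)


def creator_with_resource_vars(next_creator, **kwargs):
    # Like the strategy's creator: adjust kwargs, then build one component
    # per device by calling next_creator (two pretend devices here).
    kwargs["use_resource"] = True
    replicas = [next_creator(**kwargs) for _ in range(2)]
    return {"kind": "mirrored", "replicas": replicas}
```

Outside the scope make_variable returns a plain variable; inside it, the same call is intercepted and yields a "mirrored" wrapper around per-device components.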

1.3.3 _CurrentDistributionContext

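_CurrentDistributionContext holds the strategy-related state, enters the various scopes (variable scope, variable creator scope, resource creator scope and device scope), and returns the strategy.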

class _CurrentDistributionContext(object):
  """Context manager setting the current  tf.distribute.Strategy .

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               resource_creator_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode( 
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._resource_creator_scope = resource_creator_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
      if self._var_scope:
        self._var_scope.__enter__()
      self._var_creator_scope.__enter__()
      if self._resource_creator_scope:
        nest.map_structure(lambda scope: scope.__enter__(),
                           self._resource_creator_scope)
      if self._device_scope:
        self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of  with  scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of  with  scope."),
          e)

    if self._resource_creator_scope:
      try:
        if isinstance(self._resource_creator_scope, list):
          reversed_resource_creator_scope = self._resource_creator_scope[::-1]
          nest.map_structure(
              lambda scope: scope.__exit__(exception_type, exception_value,  
                                           traceback),
              reversed_resource_creator_scope)

        else:
          self._resource_creator_scope.__exit__(exception_type, exception_value,
                                                traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Resource creator scope nesting error: move call "
                         "to tf.distribute.set_strategy() out of  with  "
                         "scope."), e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of  with  scope."),
            e)
    _pop_per_thread_mode()

1.3.4 MirroredStrategy

From the analysis above, we know that when a Strategy is in use, the variable is ultimately produced by the Strategy's _create_variable.

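_create_variable does the actual work. It uses self._devices and then calls distribute_utils.create_mirrored_variable, passing it _real_mirrored_creator, VARIABLE_CLASS_MAPPING and VARIABLE_POLICY_MAPPING. _real_mirrored_creator creates one component variable per device inside an ops.device(d) scope, which is what determines the device each component is placed on, and it also sets the component names: the component on the first device keeps the original name, while each later component gets /replica_<index> appended to the original name so that it can be distinguished from the original variable. Each replica's initial value comes from _get_variable_creator_initial_value, which is passed the first (primary) component, so the replicas are initialized from the original variable's value.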

def _create_variable(self, next_creator, **kwargs):
  """Create a mirrored variable. See  DistributionStrategy.scope ."""
  colocate_with = kwargs.pop("colocate_with", None)
  if colocate_with is None:
    devices = self._devices
  elif isinstance(colocate_with, numpy_dataset.SingleDevice):
    with ops.device(colocate_with.device):
      return next_creator(**kwargs)
  else:
    devices = colocate_with._devices  

  def _real_mirrored_creator(**kwargs):  
    value_list = []
    for i, d in enumerate(devices):
      with ops.device(d):
        kwargs["initial_value"] = self._get_variable_creator_initial_value(
            replica_id=i,
            device=d,
            primary_var=value_list[0] if value_list else None,
            **kwargs)
        if i > 0:
          # Give replicas meaningful distinct names:
          var0name = value_list[0].name.split(":")[0]
          # We append a / to variable names created on replicas with id > 0 to
          # ensure that we ignore the name scope and instead use the given
          # name as the absolute name of the variable.
          kwargs["name"] = "%s/replica_%d/" % (var0name, i)
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with tape.stop_recording():
            v = next_creator(**kwargs)
        assert not isinstance(v, values.DistributedVariable)
        value_list.append(v)
    return value_list

  return distribute_utils.create_mirrored_variable(
      self._container_strategy(), _real_mirrored_creator,
      distribute_utils.VARIABLE_CLASS_MAPPING,
      distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)
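The naming rule in _real_mirrored_creator can be isolated as a small pure-Python function. replica_variable_names is a hypothetical helper written for illustration; it returns the names that are *requested* from next_creator. TensorFlow then resolves a name such as "Variable/replica_1/" to a tensor name like 'Variable/replica_1:0', as the MirroredVariable printout at the start of this article shows.

```python
from typing import List


def replica_variable_names(base_name: str, num_devices: int) -> List[str]:
    """Sketch of the naming rule used in _real_mirrored_creator.

    Replica 0 keeps the name TF assigned to it (e.g. "Variable:0"). Replica
    i > 0 is requested as "<var0name>/replica_<i>/", where var0name is
    replica 0's name with the ":0" output suffix stripped. The trailing "/"
    makes TF treat the name as absolute, ignoring the current name scope.
    """
    var0name = base_name.split(":")[0]
    names = [base_name]
    for i in range(1, num_devices):
        names.append("%s/replica_%d/" % (var0name, i))
    return names
```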

VARIABLE_CLASS_MAPPING determines which variable class is created; VARIABLE_POLICY_MAPPING determines which policy is used to handle read/write synchronization.

# The following mapping indicates the policy that you must use for a given
# variable  synchronization  and  aggregation  pair.
# OnWritePolicy is used for:
# (synchronization=Auto, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# (synchronization=ON_WRITE, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
# OnReadPolicy is used for:
# (synchronization=ON_READ, aggregation=NONE,SUM,MEAN,ONLY_FIRST_REPLICA)
VARIABLE_POLICY_MAPPING = {
    vs.VariableSynchronization.ON_WRITE: values_lib.OnWritePolicy,
    vs.VariableSynchronization.ON_READ: values_lib.OnReadPolicy,
}

VARIABLE_CLASS_MAPPING = {
    "VariableClass": values_lib.DistributedVariable,
    vs.VariableSynchronization.ON_WRITE: values_lib.MirroredVariable, # the entry we focus on
    vs.VariableSynchronization.ON_READ: values_lib.SyncOnReadVariable,
}
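How create_mirrored_variable picks a class from these two mappings can be sketched in pure Python. Sync is a hypothetical mirror of tf.VariableSynchronization (same member names), and strings stand in for the real classes in values_lib; rejecting NONE is our assumption about _validate_synchronization, while the AUTO-to-ON_WRITE conversion is stated in the source.

```python
from enum import Enum
from typing import Optional, Tuple


class Sync(Enum):
    """Hypothetical mirror of tf.VariableSynchronization (same member names)."""
    AUTO = 0
    NONE = 1
    ON_WRITE = 2
    ON_READ = 3


# Re-declared with string placeholders for the real values_lib classes.
VARIABLE_CLASS_MAPPING = {
    "VariableClass": "DistributedVariable",
    Sync.ON_WRITE: "MirroredVariable",
    Sync.ON_READ: "SyncOnReadVariable",
}
VARIABLE_POLICY_MAPPING = {
    Sync.ON_WRITE: "OnWritePolicy",
    Sync.ON_READ: "OnReadPolicy",
}


def pick_class(sync: Sync, use_var_policy: bool) -> Tuple[str, Optional[str]]:
    """Return (variable class, policy class) as create_mirrored_variable does."""
    if sync == Sync.NONE:
        # Assumption: NONE is rejected under a distribution strategy.
        raise ValueError("NONE synchronization is not supported here")
    if sync == Sync.AUTO:
        sync = Sync.ON_WRITE  # AUTO is converted to ON_WRITE
    if use_var_policy:
        # One generic class; the behaviour comes from the policy object.
        return VARIABLE_CLASS_MAPPING["VariableClass"], VARIABLE_POLICY_MAPPING[sync]
    # Without policies, the synchronization mode selects a dedicated subclass.
    return VARIABLE_CLASS_MAPPING[sync], None
```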

1.3.5 distribute_utils

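create_mirrored_variable in tensorflow/python/distribute/distribute_utils.py actually builds the variable. In our example, class_mapping is VARIABLE_CLASS_MAPPING, and with ON_WRITE synchronization and no variable policy the class it selects is values_lib.MirroredVariable.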

def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
  """Create distributed variables with given synchronization and aggregation."""
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  synchronization = _validate_synchronization(kwargs)
  # Update synchronization in kwargs in case it's AUTO, which is converted to
  # ON_WRITE.
  kwargs["synchronization"] = synchronization
  aggregation = _validate_aggregation(kwargs)
  use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  with tape.stop_recording():
    # build the list of per-device component variables
    value_list = real_mirrored_creator(**kwargs)
    # MirroredVariable is recreated during saved_model loading, and its
    # component variables (value_list) will have None initializer. We
    # set their initializers to no_op so that consumer like
    #  global_variables_initializer  wouldn't complain, as it groups all
    # variables' initializers thus all variables have to have initializers.
    for v in value_list:
      if hasattr(v, "_initializer_op") and v._initializer_op is None:
        v._initializer_op = control_flow_ops.no_op()
    if use_var_policy:
      # look up the policy, get the class, and create the variable
      var_policy_cls = policy_mapping.get(synchronization)
      var_policy = var_policy_cls(aggregation=aggregation)
      var_cls = class_mapping.get("VariableClass")
      result = var_cls(strategy, value_list, aggregation, var_policy=var_policy)
    else:
      var_cls = class_mapping.get(synchronization)
      result = var_cls(strategy, value_list, aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for value in value_list:
        for i, trainable_variable in enumerate(l):
          if value is trainable_variable:
            del l[i]
            break

    g.add_to_collections(var_collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in var_collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result

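The final construction logic is as follows: the _var_creator_scope member of _CurrentDistributionContext is a variable creator scope built from creator_with_resource_vars. When a variable is created, the call passes through creator_with_resource_vars layer by layer and finally produces a MirroredVariable.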

Figure 15 Creating variables

1.4 Summary

So far, our answers to the earlier questions are:

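  • How does creation reach the Strategy?
    • Reading and writing variables ultimately land on strategy or strategy.extended.
  • How is a Mirrored Variable generated?
    • Inside the scope, the user gets a context that provides the method for creating variables, so a variable the user creates within that context is naturally a Mirrored Variable.
  • How are tensors distributed to the devices?
    • When a Strategy is in use, variables are produced by the Strategy's _create_variable, which eventually calls _real_mirrored_creator.
    • _real_mirrored_creator sets the name of each component: the component on the first device keeps the original name, while later components get /replica_<index> appended to the original name so that they can be distinguished from the original variable.
    • Each component is created inside an ops.device scope for its device, so during placement it is assigned to that device.
  • How is a unified view presented externally?
    • Within the context, the user gets a Mirrored Variable, which hides the internal component variables and presents a single view. For example, a cross-replica read calls _get_cross_replica, which internally calls the Policy, and the Policy in turn calls distribute_strategy to perform the reduction.
  • How are the variables kept consistent?
    • From the earlier analysis of scatter_update, we know that an update is dispatched to strategy.extended, where the component variables are kept consistent by mechanisms such as All-Reduce. We will analyze this in detail in a later article.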

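Let's illustrate this with a diagram. Suppose a MirroredVariable A consists internally of 3 tensors. Each worker believes it is updating MirroredVariable A, but each is actually updating a different component variable, and the components are kept consistent by mechanisms such as All-Reduce.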

Figure 16 How updates work
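The consistency argument can be checked with a tiny simulation: if every replica applies the same all-reduced gradient, mirrored copies that start equal stay equal. This is a sketch with plain floats; all_reduce_mean is a hypothetical stand-in for a real collective, and averaging (rather than summing) the gradients is an assumption made for illustration.

```python
from typing import List


def all_reduce_mean(per_replica: List[float]) -> float:
    # Stand-in for a real all-reduce: every replica ends up with this value.
    return sum(per_replica) / len(per_replica)


def mirrored_sgd_step(copies: List[float], grads: List[float], lr: float) -> List[float]:
    """One synchronous SGD step over the copies of a mirrored scalar variable."""
    g = all_reduce_mean(grads)            # each replica computed its own gradient
    return [w - lr * g for w in copies]   # identical update on every replica
```

Because the reduced gradient is identical everywhere, the update is identical too, even though each replica only ever writes its own copy.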

2. ShardedVariable

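In machine learning training, if a variable is too large to fit on a single device (for example, a large embedding), it may need to be sharded across multiple devices. In TensorFlow, the concept that corresponds to this idea is ShardedVariable.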

Figure 17 ShardedVariable

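Variable sharding means splitting one variable into multiple smaller variables called shards. A ShardedVariable can be viewed as a container whose "variables" should be treated as shards. The ShardedVariable class maintains a list of smaller variables that can be stored independently on different devices (for example, multiple parameter servers), and it saves and restores them as if they were one larger variable. Variable sharding is useful for spreading the network load of accessing these shards, and also for distributing the computation and storage of a regular variable across multiple parameter servers.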

Figure 18 The ShardedVariable container

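An object of the ShardedVariable class can be saved with a given number of shards and later restored from a checkpoint into a different number of shards. The resulting SavedModel can be used by programs such as the TF serving API, but loading it with tf.saved_model.load is not supported. Because a ShardedVariable can be saved and then restored into a different number of shards depending on the restore environment (for example, the TF serving API restores it into a single shard for serving efficiency), code that uses a ShardedVariable inside a tf.function should generally not assume that it has the same number of shards at save time and at load time.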
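Conceptually, re-sharding on restore works because the checkpoint describes the full logical variable, not a particular shard layout. The NumPy sketch below shows the idea only: np.array_split is a stand-in for whatever partitioner produced the shards, and this is not TensorFlow's checkpoint code.

```python
import numpy as np

# The full logical variable, shape [10, 3]: conceptually what the checkpoint holds.
full = np.arange(30, dtype=np.float32).reshape(10, 3)

# Saved with 4 shards along axis 0 (np.array_split stands in for a partitioner).
saved_shards = np.array_split(full, 4, axis=0)

# Restored elsewhere with a different shard count, e.g. 1 shard for serving:
# reassemble the logical value, then split it according to the new layout.
restored_full = np.concatenate(saved_shards, axis=0)
restored_shards = np.array_split(restored_full, 1, axis=0)
```

The shard count changes (4 to 1) while the logical value is unchanged, which is why a tf.function should not bake in a particular number of shards.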

2.1 Questions

For ShardedVariable, we again use a few questions to guide the analysis.

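  • How are parameters stored on the parameter servers?
  • How are parameters stored in shards?
  • How is computation (the operations that update parameters with gradients) placed on the parameter servers? (analyzed in a later chapter)
  • Does the Coordinator assign computation randomly? (analyzed in a later chapter)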

2.2 Definition

The definition of ShardedVariable itself does not contain much; most of the substance lives in its base class ShardedVariableMixin, which we analyze shortly.

Figure 19 ShardedVariable definition

The definition is as follows:

class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for  Variables  that should be treated as shards.
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to  ops.Tensor ."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)
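The operator-overloading trick above (materialize the container as one tensor, then delegate to the tensor's operator) can be reproduced with NumPy. ToySharded and _var_to_array are hypothetical analogues of ShardedVariable and _var_to_tensor; only the setattr-based delegation pattern is taken from the source. TF skips __getitem__ because ShardedVariableMixin defines its own slicing.

```python
from typing import Sequence

import numpy as np


class ToySharded:
    """Hypothetical container of row shards; NOT tf ShardedVariable."""

    def __init__(self, shards: Sequence[np.ndarray]):
        self._shards = list(shards)


def _var_to_array(v: ToySharded) -> np.ndarray:
    # Analogue of _var_to_tensor: materialize the full value by concatenation.
    return np.concatenate(v._shards, axis=0)


def _overload_operator(cls, operator: str) -> None:
    array_operator = getattr(np.ndarray, operator)

    # A separate function scope per operator avoids the late-binding closure
    # bug a lambda written directly inside the loop would have.
    def _operator(v, *args, **kwargs):
        return array_operator(_var_to_array(v), *args, **kwargs)

    setattr(cls, operator, _operator)


for op in ("__add__", "__mul__", "__neg__"):
    _overload_operator(ToySharded, op)
```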

2.3 How Partitioning Works

Partitioning is one of the key parts of ShardedVariable, so let's look at how it works. Note that ShardedVariable only supports partitioning along the first dimension.

2.3.1 Base Class

The base class Partitioner contains little; its subclasses must implement __call__.

@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitiners inherit from this class.

  Partitioners should implement a  __call__  method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given  shape  and returns the partition results.
    # See docstring of  __call__  method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given  shape  and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a  tf.TensorShape , the shape to partition.
      dtype: a  tf.dtypes.Dtype  indicating the type of the partition value.
      axis: The axis to partition along.  Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where i-th value corresponds to i-th axis.
    """
    raise NotImplementedError

2.3.2 Fixed Partitioning

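FixedShardsPartitioner splits a variable into a fixed number of shards. Its docstring contains a usage example: there, with axis = 0, min(self._num_shards, shape.dims[axis].value) = min(2, 10) = 2, so the variable is split into two shards and the result is [2, 1].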

@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new  FixedShardsPartitioner .

    Args:
      num_shards:  int , number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
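The arithmetic of FixedShardsPartitioner.__call__ is easy to check without TensorFlow. fixed_shards is a hypothetical pure-Python mirror that takes the shape as a list of ints instead of a tf.TensorShape.

```python
from typing import List


def fixed_shards(shape: List[int], num_shards: int, axis: int = 0) -> List[int]:
    """Pure-Python mirror of FixedShardsPartitioner.__call__ (shape as a list)."""
    result = [1] * len(shape)
    # Never request more shards than there are rows along `axis`.
    result[axis] = min(num_shards, shape[axis])
    return result
```

For example, fixed_shards([10, 3], 2) gives [2, 1], while a variable with a single row cannot be split further, so fixed_shards([1, 3], 2) gives [1, 1].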

2.3.3 Minimum-Size Partitioning

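MinSizePartitioner is a partitioner that allocates a minimum size per shard. It ensures that each shard has at least min_shard_bytes and tries to allocate as many shards as possible, i.e., it keeps the shard size as small as possible. The maximum number of such shards (the upper bound) is given by max_shards.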

@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least  min_shard_bytes , and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
   max_shards .

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new  MinSizePartitioner .

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)

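min_max_variable_partitioner implements the actual logic. It returns a partitioner that partitions a variable of the given shape and dtype so that each partition holds a slice of at least min_slice_size bytes. The maximum number of such partitions (the upper bound) is given by max_partitions.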

@tf_export(v1=["min_max_variable_partitioner"])
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition has a minimum of  min_slice_size  slice of the
  variable. The maximum number of such partitions (upper bound) is given by
   max_partitions .

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size of the variable slice per partition. Defaults
      to 256K.
    bytes_per_string_element: If the  Variable  is of type string, this provides
      an estimate of how large each scalar in the  Variable  is.

  Returns:
    A partition function usable as the  partitioner  argument to
     variable_scope  and  get_variable .

  """
  def _partitioner(shape, dtype):
    """Partitioner that partitions list for a variable of given shape and type.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If  max_partitions  >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If  max_partitions  < 16, this function would return
        [ max_partitions , 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    #  max_partitions  allows.
    partitions_list[axis] = max(1, min(shape.dims[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner
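To check the numbers in the docstrings above, here is a pure-Python mirror of the arithmetic in min_max_variable_partitioner. min_max_partitions is hypothetical: it takes a list shape and the element size in bytes (dtype.size) instead of a TensorShape and DType.

```python
import math
from typing import List


def min_max_partitions(shape: List[int], dtype_size: int, max_partitions: int = 1,
                       axis: int = 0, min_slice_size: int = 256 << 10) -> List[int]:
    """Pure-Python mirror of the arithmetic in min_max_variable_partitioner."""
    if axis >= len(shape):
        raise ValueError("axis %d out of range for shape %s" % (axis, shape))
    total_size_bytes = math.prod(shape) * dtype_size
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # Capped by the dimension size and by max_partitions, and at least 1.
    partitions_list[axis] = max(1, min(shape[axis], max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
```

A float32 [1024, 1024] variable is 4 MiB, i.e. 16 slices of 256 KiB, so it yields [16, 1] when max_partitions >= 16 and [max_partitions, 1] otherwise, matching the docstring.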

2.3.4 Maximum-Size Partitioning

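This partitioner ensures that each shard is at most max_shard_bytes in size and tries to allocate as few shards as possible, i.e., it keeps shards as large as possible. If the partitioner hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, meaning the number of shards is not limited.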

@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below  max_shard_bytes .

  This partitioner ensures each shard has at most  max_shard_bytes , and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the  max_shards  limit, then each shard may end up
  larger than  max_shard_bytes . By default  max_shards  equals  None  and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new  MaxSizePartitioner .

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in  int  created taking
        precedence over  max_shard_bytes .
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)

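variable_axis_size_partitioner implements the actual logic. It shards a variable along one axis, trying to keep the largest shard below max_shard_bytes. If the partitioner hits the max_shards limit, each shard may end up larger than max_shard_bytes. By default max_shards is None, meaning the number of shards is not limited.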

A reasonable value for max_shard_bytes is (64 << 20) - 1, i.e. just under 64MB, which keeps each shard below the protobuf byte limit.

@tf_export(v1=["variable_axis_size_partitioner"])
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below  max_shard_bytes .

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below  max_shard_bytes .  In practice, this is not
  always possible when sharding along only one axis.  When this happens,
  this axis is sharded as much as possible (i.e., every dimension becomes
  a separate shard).

  If the partitioner hits the  max_shards  limit, then each shard may end up
  larger than  max_shard_bytes . By default  max_shards  equals  None  and no
  limit on the number of shards is enforced.

  One reasonable value for  max_shard_bytes  is  (64 << 20) - 1 , or almost
   64MB , to keep below the protobuf byte limit.

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along.  Default: outermost axis.
    bytes_per_string_element: If the  Variable  is of type string, this provides
      an estimate of how large each scalar in the  Variable  is.
    max_shards: The maximum number of shards in int created taking precedence
      over  max_shard_bytes .

  Returns:
    A partition function usable as the  partitioner  argument to
     variable_scope  and  get_variable .

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """

  def _partitioner(shape, dtype):
    """Partitioner that partitions shards to have max_shard_bytes total size.

    Args:
      shape: A  TensorShape .
      dtype: A  DType .

    Returns:
      A tuple representing how much to slice each axis in shape.

    Raises:
      ValueError: If shape is not a fully defined  TensorShape  or dtype is not
        a  DType .
    """
    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape.dims[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis] slices?
    axis_shards = int(math.ceil(
        1.0 * shape.dims[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner
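A pure-Python mirror of the arithmetic in variable_axis_size_partitioner reproduces the MaxSizePartitioner docstring examples. axis_size_partitions is hypothetical: it takes a list shape and the element size in bytes rather than a TensorShape and DType. The last test case applies the suggested (64 << 20) - 1 limit to an assumed float32 [1000000, 64] embedding.

```python
import math
from typing import List, Optional


def axis_size_partitions(shape: List[int], dtype_size: int, max_shard_bytes: int,
                         axis: int = 0, max_shards: Optional[int] = None) -> List[int]:
    """Pure-Python mirror of the arithmetic in variable_axis_size_partitioner."""
    partitions = [1] * len(shape)
    # Bytes in one slice along `axis` (e.g. one row when axis == 0).
    bytes_per_slice = 1.0 * (math.prod(shape) / shape[axis]) * dtype_size
    # How many slices fit in one shard of at most max_shard_bytes (at least 1)?
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards are needed to hold all slices along `axis`?
    axis_shards = int(math.ceil(1.0 * shape[axis] / slices_per_shard))
    if max_shards:
        axis_shards = min(max_shards, axis_shards)
    partitions[axis] = axis_shards
    return partitions
```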

2.4 ShardedVariableMixin

As mentioned earlier, ShardedVariableMixin is the core, so let's analyze it next. The main member variables of ShardedVariableMixin are:

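  • _variables : the partitioned (shard) variables.

  • _var_offsets : the offset of each shard within the ShardedVariableMixin. In other words, _variables are treated as one whole, and the offsets are used to locate the corresponding data within it.

  • _shape : the shape of the ShardedVariableMixin, i.e. of the full variable.

  • _name : the name of the ShardedVariableMixin.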

class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  def __init__(self,
               variables: Sequence[variables_lib.Variable],
               name='ShardedVariable'):
    """Treats  variables  as shards of a larger Variable.

    Args:
      variables: A list of  ResourceVariable s that comprise this sharded
        variable. Variables should not be shared between different
         ShardedVariableMixin  objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    var_dtypes = {v.dtype for v in variables}
    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    # compute the overall shape
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])
    
    # compute each shard's offset within the whole
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)
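The shape and offset computation in __init__ can be isolated as a pure-Python function over shard shapes. shard_layout is hypothetical. The source comment states that all shards must agree on axes > 0; raising ValueError when they do not is our assumption about how that is enforced, since the excerpt above computes higher_dim_shapes without showing the check.

```python
from typing import List, Sequence, Tuple


def shard_layout(shard_shapes: Sequence[Sequence[int]]) -> Tuple[List[int], List[List[int]]]:
    """Compute the full shape and per-shard offsets like ShardedVariableMixin.

    Only axis 0 is partitioned, so offsets on every other axis stay 0.
    """
    higher_dim_shapes = {tuple(s[1:]) for s in shard_shapes}
    if len(higher_dim_shapes) != 1:
        raise ValueError("all shards must have the same shape for axes > 0")
    first_dim = sum(s[0] for s in shard_shapes)
    full_shape = [first_dim] + list(shard_shapes[0][1:])
    offsets = [[0] * len(shard_shapes[0]) for _ in shard_shapes]
    for i in range(1, len(shard_shapes)):
        # Each shard starts where the previous one ends along axis 0.
        offsets[i][0] = offsets[i - 1][0] + shard_shapes[i - 1][0]
    return full_shape, offsets
```

With shard shapes (1, 2), (2, 2) and (1, 2) this gives a full shape of [4, 2] and offsets [[0, 0], [1, 0], [3, 0]].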

2.4.1 Usage

Let's look at how it is used with the following example.

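import numpy as np
import tensorflow as tf
# ShardedVariableMixin is internal; it is not exported in the public tf API.
from tensorflow.python.distribute.sharded_variable import ShardedVariableMixin

variables = [
  tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32),
  tf.Variable(np.array([[3, 2], [0, 1]]), shape=(2, 2), dtype=tf.float32),
  tf.Variable(np.array([[3, 2]]), shape=(1, 2), dtype=tf.float32)
]
sharded_variable = ShardedVariableMixin(variables)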

The internal member variables of sharded_variable print as follows. As you can see, _var_offsets treats all parameter shards as one whole and locates each shard within it.

_shape = {TensorShape: 2} (4, 2)
_var_offsets = {list: 3} [[0, 0], [1, 0], [3, 0]]
first_dim = {int} 4

In the example above, the three variables packed together look as follows, and offsets can be used to look up data in this combined view.

[[3, 2], [3, 2], [0, 1], [3, 2]]

Let's look at another illustration. Suppose the parameter has 4 partitions:

Figure 20 Partitions

If the variables are all placed on parameter servers, it looks like this:

Figure 21 Partitions and parameter servers

2.4.2 Getting a Slice

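Next, let's look at how to get a slice, i.e., how to extract a specified region of the sharded variable as a tensor. The logic is: analyze the incoming spec and, according to its contents, process the sharded variable to obtain the requested part of the parameter shards.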

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to  Tensor.__getitem__ . Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on  slice_spec .

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If  spec_spec  contains Tensor.
    """

    # boolean mask: materialize the whole variable as a tensor and apply the mask
    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      # A slice: decompose the global slice into one local slice per shard.
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
        
      # Walk through the shards and use the offsets to locate the requested row.
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

What does a spec usually look like? The following example illustrates it clearly.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If  slice_spec  is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If  slice_spec  is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If  slice_spec  is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]

The code that decomposes the spec is as follows:

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

     ShardedVariable  only supports first dimension partitioning, thus
     slice_spec  must be for first dimension.

    Args:
      slice_spec: A python  slice  object that specifies the global slicing.

    Returns:
      A list of python  slice  objects or None specifying the local slicing for
      each component variable. None means no slicing.

    """
    result = []
    # Normalize start, end and stop.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor ( cur ) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result
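To check the decomposition against the v0/v1/v2 example, here is a pure-Python re-statement of the logic above (no TensorFlow). decompose_slice is a hypothetical name; for simplicity it takes the shard start rows as a flat list such as [0, 3, 6] instead of [[row, 0, ...], ...], plus the total size of the first dimension. Concatenating the per-shard pieces (in reverse shard order for a negative step, as __getitem__ does) should equal slicing the flat list directly.

```python
import math


def decompose_slice(slice_spec, starts, dim0):
    """Split a global first-dimension slice into one local slice (or None) per shard.

    starts: first global row of each shard, e.g. [0, 3, 6]; dim0: total rows.
    """
    step = slice_spec.step if slice_spec.step is not None else 1
    if step == 0:
        raise ValueError("slice step cannot be zero")
    start = slice_spec.start
    if start is None:
        start = 0 if step > 0 else dim0 - 1
    elif start < 0:
        start += dim0
    end = slice_spec.stop
    if end is None:
        end = dim0 if step > 0 else -1  # -1 means "before the first row"
    elif end < 0:
        end += dim0

    result = []
    cur = start
    n = len(starts)
    order = range(n) if step > 0 else range(n - 1, -1, -1)
    for i in order:
        var_start = starts[i]
        var_end = starts[i + 1] if i < n - 1 else dim0
        if step > 0:
            if cur < var_start:  # jump to the first selected row inside this shard
                cur += step * math.ceil((var_start - cur) / step)
            if cur >= var_end or cur >= end:
                result.append(None)
            else:
                result.append(slice(cur - var_start, min(end, var_end) - var_start, step))
        else:
            if cur >= var_end:  # jump to the last selected row inside this shard
                cur += step * math.ceil((var_end - cur - 1) / step)
            if cur < var_start or cur <= end:
                result.append(None)
            else:
                local_end = end - var_start if end >= var_start else None
                result.append(slice(cur - var_start, local_end, step))
    if step < 0:
        result.reverse()  # back to shard order
    return result


v = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]  # v0, v1, v2 from the example
starts = [0, 3, 6]
flat = [x for part in v for x in part]

for spec in (slice(None), slice(2, 8, 3), slice(9, 3, -2)):
    local = decompose_slice(spec, starts, len(flat))
    pieces = [v[i][s] for i, s in enumerate(local) if s is not None]
    if spec.step is not None and spec.step < 0:
        pieces.reverse()  # __getitem__ also reverses the pieces for a negative step
    print(local, sum(pieces, []) == flat[spec])
```

For slice(2, 8, 3) the result is [slice(2, 3, 3), slice(2, 3, 3), None], and for slice(9, 3, -2) it is [None, slice(2, 0, -2), slice(3, None, -2)], matching the docstring example.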

2.4.3 Embedding

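Next, let's look at embedding lookup. The override below simply forwards partition_strategy, name, validate_indices, max_norm and the other arguments to embedding_ops.embedding_lookup, passing params.variables (the list of shards) as params. The default partition strategy here is 'mod'.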

# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)

The flow then reaches embedding_lookup (tensorflow/python/ops/embedding_ops.py), and from there we need to follow _embedding_lookup_and_transform.

@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given  ids  from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
   params .  It is a generalization of  tf.gather , where  params  is
  interpreted as a partitioning of a large embedding tensor.   params  may be
  a  PartitionedVariable  as returned by using  tf.compat.v1.get_variable() 
  with a partitioner.

  If  len(params) > 1 , each element  id  of  ids  is partitioned between
  the elements of  params  according to the  partition_strategy .
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first  (max_id + 1) % len(params)  partitions will
  be assigned one more id.

  If the input ids are ragged tensors, partition variables are not supported and
  the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape  shape(ids) + shape(params)[1:] .

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
       PartitionedVariable , created by partitioning along dimension 0. Each
      element must be appropriately sized for the given  partition_strategy .
    ids: A  Tensor  or a 'RaggedTensor' with type  int32  or  int64  containing
      the ids to be looked up in  params .
    partition_strategy: A string specifying the partitioning strategy, relevant
      if  len(params) > 1 . Currently  "div"  and  "mod"  are supported. Default
      is  "mod" .
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in  indices  are always validated to be within range.  If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not  None , each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A  Tensor  or a 'RaggedTensor', depending on the input, with the same type
    as the tensors in  params .

  Raises:
    ValueError: If  params  is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)

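_embedding_lookup_and_transform contains the actual partitioning code. Let's first walk through it with examples.

  • If partition_strategy is "mod", each id is assigned to partition p = id % len(params). For example, 13 ids split across 5 partitions give: [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
  • If partition_strategy is "div", ids are assigned to partitions contiguously. In the same example, 13 ids split across 5 partitions give: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].

The code is as follows: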

def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the  transform_fn  argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single  Tensor 
  argument of the same type as the  params  tensor and should return a
   Tensor . The shape of the argument will be the same as  params  except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If  params  is empty.
  """

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    # ... (code omitted; among other things it sets np = len(params), the number of partitions)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

  # ... (remaining code omitted)
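The two branches above only do integer arithmetic, so they can be replayed with plain Python ints. The sketch below is an illustration, not TensorFlow code: assign_partitions and group_by_partition are hypothetical helpers, num_partitions plays the role of np (len(params)), and the "div" branch assumes num_total_ids >= num_partitions so that ids_per_partition is non-zero. It reproduces the 13-id, 5-partition examples listed earlier.

```python
def assign_partitions(ids, num_partitions, num_total_ids, strategy="mod"):
    """Return (p_assignments, new_ids) for a flat list of ids.

    num_total_ids is the sum of the first dimensions of all shards and is
    only used by the "div" strategy.
    """
    if strategy == "mod":
        p_assignments = [i % num_partitions for i in ids]
        new_ids = [i // num_partitions for i in ids]
    elif strategy == "div":
        ids_per_partition = num_total_ids // num_partitions
        extras = num_total_ids % num_partitions
        # The first `extras` partitions hold ids_per_partition + 1 ids each.
        p_assignments = [
            max(i // (ids_per_partition + 1), (i - extras) // ids_per_partition)
            for i in ids
        ]
        new_ids = [
            i % (ids_per_partition + 1) if p < extras
            else (i - extras) % ids_per_partition
            for i, p in zip(ids, p_assignments)
        ]
    else:
        raise ValueError("Unrecognized partition strategy: " + strategy)
    return p_assignments, new_ids


def group_by_partition(ids, p_assignments, num_partitions):
    """Collect ids into one bucket per partition."""
    buckets = [[] for _ in range(num_partitions)]
    for i, p in zip(ids, p_assignments):
        buckets[p].append(i)
    return buckets


ids = list(range(13))
for strategy in ("mod", "div"):
    p, new_ids = assign_partitions(ids, 5, 13, strategy)
    print(strategy, group_by_partition(ids, p, 5), new_ids)
```

new_ids is the row to read inside the chosen shard. For the Model example that follows (two one-row shards [3.0] and [2.0] with "mod"), id 1 goes to partition 1 % 2 = 1 at local row 1 // 2 = 0, which is why model.fn(1) returns 2.0.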

How is the embedding used? The usage below is taken from the docstring: it builds a ShardedVariable, and the model operates on this variable through embedding_lookup.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)

As a diagram: the worker below performs lookups on two parameter servers in parallel to fetch the embeddings.

Figure 22 Processing embeddings

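2.5 Construction

For the construction of ShardedVariable, we look directly at how ParameterServerStrategyV2 builds it.

2.5.1 Variable sharding

To enable variable sharding, you can pass a variable_partitioner when constructing the ParameterServerStrategy object. The variable_partitioner is invoked every time a variable is created and is expected to return the number of shards along each dimension of the variable. Some out-of-the-box variable_partitioners are provided, such as tf.distribute.experimental.partitioners.MinSizePartitioner. It is recommended to use a size-based partitioner such as tf.distribute.experimental.partitioners.MinSizePartitioner to avoid partitioning small variables, which could hurt model training speed.

When a variable_partitioner is passed in and you create a variable directly under strategy.scope(), the variable becomes a container type with a variables property, which gives access to the list of shards. In most cases this container is automatically converted to a tensor by concatenating all the shards, so it can be used as a normal variable. On the other hand, some TensorFlow methods, such as tf.nn.embedding_lookup, provide efficient implementations for this container type that avoid the automatic concatenation.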
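A partitioner is just a callable that maps a variable's shape and dtype to a list of shard counts, one per dimension. The sketch below is a hypothetical, pure-Python stand-in that follows the size-based idea described above; the rule it uses (shards = ceil(total_bytes / min_shard_bytes), capped by max_shards and by the first dimension, never below 1) is our assumption about how such a partitioner behaves, not a copy of MinSizePartitioner. Shapes are plain tuples and the element size in bytes is passed directly instead of a tf.DType.

```python
import math


def hypothetical_min_size_partitioner(min_shard_bytes, max_shards):
    """Return partitioner(shape, bytes_per_element) -> shards per axis.

    Only axis 0 is partitioned, matching what ShardedVariable supports.
    """
    def partitioner(shape, bytes_per_element):
        total_bytes = math.prod(shape) * bytes_per_element
        shards = math.ceil(total_bytes / min_shard_bytes)
        # Cap by max_shards and by the number of rows; keep at least one shard.
        shards = max(1, min(shards, max_shards, shape[0]))
        return [shards] + [1] * (len(shape) - 1)
    return partitioner


partitioner = hypothetical_min_size_partitioner(min_shard_bytes=1024, max_shards=4)
print(partitioner((100, 10), 4))  # 4000 bytes of float32
print(partitioner((8, 2), 4))     # 64 bytes: too small to split
print(partitioner((2, 1000), 4))  # large, but only 2 rows
```

The three calls give [4, 1], [1, 1] and [2, 1]. When the first entry is 1, _create_variable (section 2.5.3) skips sharding and falls back to plain round-robin placement.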

2.5.2 Initialization

When ParameterServerStrategyV2Extended is initialized, it stores the incoming variable_partitioner in _variable_partitioner, and also records the number of parameter servers and the number of workers.

class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see  tf.distribute.StrategyExtended  doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver,
               variable_partitioner):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
        "worker", []))
    self._variable_count = 0

    self._variable_partitioner = variable_partitioner
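_num_ps and _num_workers are simply the number of task addresses listed under the "ps" and "worker" jobs. A minimal sketch with a hypothetical cluster (the host names are made up), using a plain dict in place of cluster_resolver.cluster_spec().as_dict():

```python
# Hypothetical cluster layout; in TensorFlow this dict would come from
# cluster_resolver.cluster_spec().as_dict().
cluster = {
    "chief": ["chief0.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
    "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
}

# Same counting as in __init__ above: a missing job counts as zero tasks.
num_ps = len(cluster.get("ps", []))
num_workers = len(cluster.get("worker", []))
print(num_ps, num_workers)
```

With this layout both counts are 2, so round-robin placement will alternate between ps task 0 and ps task 1.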

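2.5.3 Creating variables

Next, let's look at the creation process, that is, how a variable is sharded across different parameter servers. The idea is:

  • Variables that must be colocated with another variable (colocate_with) are never partitioned.
  • If no partitioner is configured, the variable is placed on a parameter server with the round-robin (RR) policy (_create_variable_round_robin).
  • If a partitioner is configured, the following is done:
    • Rank-0 (scalar) variables are not partitioned.
    • Get the number of partitions from _variable_partitioner; if it is 1 along the first axis, no partitioning is done.
    • The number of partitions cannot exceed the size of the first dimension; if it does, the first-dimension size is used instead.
    • Compute the offset of each shard.
    • Create the smaller per-shard tensors.
    • Use _create_variable_round_robin to build the list of shard variables.
    • Build the ShardedVariable from that list.
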
  def _create_variable(self, next_creator, **kwargs):
    """Implements StrategyExtendedV2._create_variable.

    Creates a  Variable  or a  ShardedVariable . A  ShardedVariable  will be
    created if satisfying all the following criteria:
      1.  self._variable_partitioner  results in more than one partition on the
         first axis.
      2. variable's rank is greater than 0.
      3. variable is not colocated with another variable.
    Otherwise a  Variable  will be created.

    Args:
      next_creator: See  variable_scope.variable_creator_scope ; the next
        creator in the chain.
      **kwargs: Passed through to the next creator.

    Returns:
      A  Variable  or  ShardedVariable .
    """

    var_creator = self._create_var_creator(next_creator, **kwargs)
    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
      colocate_with = kwargs["colocate_with"]
      # Clear the variable scope to avoid possible conflicts between device
      # scope and colocation scope.
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          var = var_creator(**kwargs)
          return var

    # No partitioner configured: place the variable on a PS with round-robin.
    if self._variable_partitioner is None:
      return self._create_variable_round_robin(var_creator, **kwargs)

    # From here on, a partitioner is configured.
    name = kwargs.get("name", None)
    initial_value = kwargs.get("initial_value", None)

    # Two cases where initial_value can be a callable:
    #   1. initial_value is passed as a callable, e.g, an  initializer  class.
    #   2. restoring from checkpoint, initial_value is a
    #     "CheckpointInitialValueCallable".
    init_from_fn = callable(initial_value)

    dtype = kwargs.get("dtype", None)
    shape = kwargs.get("shape", None)
    if init_from_fn and (shape is None or dtype is None):
      init_from_fn = False
      initial_value = initial_value()
    if not init_from_fn:
      # The initial_value is created on coordinator, it will need to be sent to
      # ps for variable initialization, which can be inefficient and can
      # potentially hit the 2GB limit on protobuf serialization.
      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
      dtype = initial_value.dtype
      shape = initial_value.shape
    else:
      shape = tensor_shape.as_shape(shape)

    # Rank-0 (scalar) variables are not partitioned.
    if shape.rank == 0:  # Skip partitioning rank-0 variable.
      return self._create_variable_round_robin(var_creator, **kwargs)

    # Get the number of partitions.
    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if num_partitions[0] == 1:  # no partition
      return self._create_variable_round_robin(var_creator, **kwargs)

    # The number of partitions cannot exceed shape[0]; cap it at shape[0].
    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions
    
    # Compute the shard offsets.
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
      if i == 0:
        offsets.append(0)
      else:
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
      if not init_from_fn:
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
    
      partition_shape = (offsets[shard_index + 1] -
                         offsets[shard_index],) + shape[1:]
      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
      arg_spec = tf_inspect.getfullargspec(initial_value)
      if ("shard_info" not in arg_spec.args and
          "shard_info" not in arg_spec.kwonlyargs):
        try:
          value = initial_value(
              partition_shape=partition_shape,
              partition_offset=partition_offset)
        except (TypeError, ValueError):
          # TypeError: Initializer doesn't accept kwargs
          # ValueError: Initializer doesn't accept partition kwargs
          # In both cases we go ahead creating the full value and then slice.
          value = initial_value()

        if value.shape == partition_shape:
          # Initializer supports partition: value is the partition value.
          return value
        else:
          # Initializer doesn't support partition: value is the full value
          # and needs to be sliced to get the partition value.
          return value[offsets[shard_index]:offsets[shard_index + 1]]
      else:
        # For compatibility with  CheckpointInitialValueCallable .
        return initial_value(
            shard_info=trackable.ShardInfo(
                shape=tensor_shape.as_shape(partition_shape),
                offset=partition_offset))

    # Create the smaller per-shard variables.
    var_list = []
    for i in range(num_partitions):
      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
      kwargs["initial_value"] = lambda: init_shard_fn(i)  # initializer for shard i
      if name is not None:
        kwargs["name"] = "{}/part_{}".format(name, i)
      # _create_variable_round_robin decides which PS this shard is placed on.
      var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

    # Build the ShardedVariable from the list of shard variables.
    result = sharded_variable.ShardedVariable(var_list)
    return result
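The offset arithmetic in _create_variable, and the slicing done by init_shard_fn when the initializer cannot produce a partition directly, can be replayed in plain Python. shard_offsets and split_initial_value are hypothetical helper names for this sketch; the 10-row/4-partition case reproduces the comment in the code, and the 100-row/2-partition case matches the Dense example in section 2.6.

```python
def shard_offsets(dim0, requested_partitions):
    """Row offsets for splitting the first dimension, following _create_variable."""
    # Never create more shards than there are rows.
    num_partitions = min(requested_partitions, dim0)
    base = dim0 // num_partitions
    extra = dim0 % num_partitions
    offsets = [0]
    for i in range(1, num_partitions):
        # The first `extra` shards each get one additional row.
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(dim0)
    return offsets


def split_initial_value(rows, offsets):
    """Slice a full initial value into shards, like the fallback in init_shard_fn."""
    return [rows[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]


print(shard_offsets(10, 4))   # same case as the comment in _create_variable
print(shard_offsets(100, 2))  # a [100, 10] weight split into two shards
print(shard_offsets(3, 8))    # more partitions requested than rows
print(split_initial_value(list(range(10)), shard_offsets(10, 4)))
```

The calls give [0, 3, 6, 8, 10], [0, 50, 100] (two (50, 10) shards) and [0, 1, 2, 3] (capped at 3 shards); the full value 0..9 is split into [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]].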

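In the logic above, both branches use _create_variable_round_robin, which decides placement with a round-robin policy. In effect, it just assigns each variable a device name (/job:ps/task:N/device:CPU:0); subsequent placement is carried out according to that device name.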

  def _create_variable_round_robin(self, next_creator, **kwargs):
    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.colocate_with(None, ignore_existing=True):
      # Explicitly set CPU:0 device for PS in case create variable is called
      # inside replica_fn and worker has with GPU:0 scope.
      with ops.device("/job:ps/task:%d/device:CPU:0" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on "
            "/job:ps/task:%d/device:CPU:0",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var
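A toy model of the placement rule: _create_variable_round_robin only computes a device string from a single counter shared by all variables created through the strategy. RoundRobinPlacer is a hypothetical class for illustration; it creates no variables and only reproduces the device-name arithmetic.

```python
class RoundRobinPlacer:
    """Toy version of the device choice in _create_variable_round_robin."""

    def __init__(self, num_ps):
        self._num_ps = num_ps
        self._variable_count = 0  # shared by every variable and every shard

    def next_device(self):
        device = "/job:ps/task:%d/device:CPU:0" % (self._variable_count % self._num_ps)
        self._variable_count += 1
        return device


placer = RoundRobinPlacer(num_ps=2)
# e.g. three shards of one partitioned variable, then one unpartitioned variable
devices = [placer.next_device() for _ in range(4)]
print(devices)
```

Because the counter is shared, the shards of one variable and the next variable simply continue the rotation: with 2 PS, the three shards land on tasks 0, 1, 0 and the following variable on task 1.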

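The next_creator argument of _create_variable_round_robin is usually the var_creator returned by the method below. When more than one replica is in sync, each newly created variable is wrapped first in a CachingVariable and then in an AggregatingVariable; with a single replica, only CachingVariable is used. These wrapped variables make up var_list, which is then used to build the ShardedVariable. Here we mainly cover AggregatingVariable.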

  def _create_var_creator(self, next_creator, **kwargs):
    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

    def var_creator(**kwargs):
      """Create an AggregatingVariable."""
      # Create and wrap the variable.
      v = next_creator(**kwargs)
      wrapped_v = ps_values.CachingVariable(v)
      wrapped = ps_values.AggregatingVariable(self._container_strategy(),
                                              wrapped_v, aggregation)
      return wrapped

    if self._num_replicas_in_sync > 1:
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s for variable: %s" %
                         (aggregation, kwargs.get("name")))
      return var_creator
    else:
      def variable_creator_single_replica(**kwargs):
        v = next_creator(**kwargs)
        return ps_values.CachingVariable(v)
      return variable_creator_single_replica

2.5.4 AggregatingVariable

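AggregatingVariable wraps a variable so that updates to it are aggregated across replicas. Taking _assign_func as an example: in cross-replica context it uses _distribute_strategy.extended.update to operate on the variable; in replica context it first aggregates the per-replica values through merge_call and then calls strategy.extended.update with the aggregated value.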

# Variable used in PSStrategy TF 1, TF2 and CentralStorageStrategy.
class AggregatingVariable(resource_variable_ops.BaseResourceVariable,
                          core.Tensor):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def __deepcopy__(self, memo):
    """Perform a deepcopy of the  AggregatingVariable .

    Unlike the deepcopy of a regular tf.Variable, this keeps the original
    strategy and devices of the  AggregatingVariable .  To avoid confusion
    with the behavior of deepcopy on a regular  Variable  (which does
    copy into new devices), we only allow a deepcopy of a  AggregatingVariable 
    within its originating strategy scope.

    Args:
      memo: The memoization object for  deepcopy .

    Returns:
      A deep copy of the current  AggregatingVariable .

    Raises:
      RuntimeError: If trying to deepcopy into a different strategy.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      v = copy.deepcopy(self._v, memo)

    copied_variable = type(self)(
        strategy=self._distribute_strategy,
        v=v,
        aggregation=self._aggregation)

    memo[id(self)] = copied_variable

    return copied_variable

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if ds_context.in_cross_replica_context():
        if distribute_lib.get_update_replica_id() is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = ds_context.get_replica_context()
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError(
              values_util.aggregation_error_msg.format(
                  variable_type="AggregatingVariable"))

        def merge_fn(strategy,
                     value,
                     use_locking=False,
                     name=None,
                     read_value=True):
          v = values_util.apply_aggregation(strategy, value, self._aggregation,
                                            self)
          if name and isinstance(name, values.PerReplica):
            name = name.values[0]
          return strategy.extended.update(
              self,
              f,
              args=(v,),
              kwargs={
                  "use_locking": use_locking,
                  "name": name,
                  "read_value": read_value
              })
        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
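The replica-context path can be summarized as "reduce first, then update once". The sketch below is a hypothetical, single-process imitation: Aggregation stands in for tf.VariableAggregation, reduce_per_replica plays the role of values_util.apply_aggregation, and ToyAggregatingVariable.assign_add plays the role of merge_fn followed by extended.update. No devices or cross-replica communication are modeled.

```python
from enum import Enum


class Aggregation(Enum):
    """Hypothetical stand-in for tf.VariableAggregation."""
    NONE = 0
    SUM = 1
    MEAN = 2
    ONLY_FIRST_REPLICA = 3


def reduce_per_replica(per_replica_values, aggregation):
    """Reduce one value per replica to the single value used for the update."""
    if aggregation is Aggregation.NONE:
        # AggregatingVariable also raises ValueError in this case.
        raise ValueError("NONE aggregation cannot be used for updates in replica context")
    if aggregation is Aggregation.SUM:
        return sum(per_replica_values)
    if aggregation is Aggregation.MEAN:
        return sum(per_replica_values) / len(per_replica_values)
    return per_replica_values[0]  # ONLY_FIRST_REPLICA


class ToyAggregatingVariable:
    """Single-process imitation of merge_fn: reduce first, then update once."""

    def __init__(self, value, aggregation):
        self.value = value
        self.aggregation = aggregation

    def assign_add(self, per_replica_deltas):
        reduced = reduce_per_replica(per_replica_deltas, self.aggregation)
        self.value += reduced  # the one update that extended.update would apply
        return self.value


deltas = [0.5, 1.5]  # what two replicas each try to add
for agg in (Aggregation.SUM, Aggregation.MEAN, Aggregation.ONLY_FIRST_REPLICA):
    print(agg.name, ToyAggregatingVariable(1.0, agg).assign_add(deltas))
```

Starting from 1.0, SUM gives 3.0, MEAN gives 2.0 and ONLY_FIRST_REPLICA gives 1.5: every replica sees the same result because only one aggregated update is applied.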

2.6 Usage

The example below shows how ShardedVariable is used. Dense builds a variable self.w of shape [100, 10]; after partitioning, it becomes two tensors of shape (50, 10).

  class Dense(tf.Module):
    def __init__(self, name=None):
      super().__init__(name=name)
      self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

    def __call__(self, x):
      return x * self.w

  # Partition the dense layer into 2 shards.
  variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
      num_shards = 2))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner = variable_partitioner)
  with strategy.scope():
    dense = Dense()  # Created inside the strategy scope, so the variable is automatically split into 2 shards.
    
  assert len(dense.variables) == 2
  assert isinstance(dense.variables[0], tf.Variable)
  assert isinstance(dense.variables[1], tf.Variable)
  assert dense.variables[0].shape == (50, 10)
  assert dense.variables[1].shape == (50, 10)

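ShardedVariable is also a form of model parallelism. For example, the matrix AB is split along its first dimension across two parameter servers, each part is multiplied by C separately, and the partial results are gathered on the worker and concatenated into the final result tensor.

Figure 23 Concatenating tensors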
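Because ShardedVariable only splits the first dimension, a product W·C with W = [A; B] (rows of A stacked on top of rows of B) equals the row-wise concatenation of A·C and B·C. Below is a tiny pure-Python check of that identity. The matrices are made up for illustration, and matmul is a naive helper, not a TensorFlow op.

```python
def matmul(a, b):
    """Naive matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# W is split along its first dimension into A (e.g. on ps0) and B (e.g. on ps1).
A = [[1, 2], [3, 4]]
B = [[5, 6]]
C = [[1, 0, 2], [0, 1, 3]]

W = A + B                              # concatenation along axis 0
full = matmul(W, C)                    # what a single unsharded variable would give
sharded = matmul(A, C) + matmul(B, C)  # per-shard products, concatenated on the worker
print(full == sharded)
```

Both sides equal [[1, 2, 8], [3, 4, 18], [5, 6, 28]], so each parameter server can compute its part independently and the worker only has to concatenate.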
