在多个设备上扩展#

本指南演示如何使用 Flax NNX Module 在多个设备和主机(如 GPU、Google TPU 和 CPU)上扩展,例如使用 JAX 实时编译机制 (jax.jit).

概述#

Flax 依赖 JAX 进行数值计算,并跨多个设备(如 GPU 和 TPU)扩展计算。扩展的核心是 JAX 实时编译器 jax.jit。在本指南中,您将使用 Flax 自带的 nnx.jit,它包装了 jax.jit 并更方便地与 Flax NNX Module 协同工作。

JAX 编译遵循 单程序多数据 (SPMD) 范式。这意味着您像在单个设备上运行一样编写 Python 代码,而 jax.jit 会自动编译并在多个设备上运行它。

为了确保编译性能,您通常需要指示 JAX 如何跨设备分片模型的变量。这就是 Flax NNX 的分片元数据 API - flax.nnx.spmd - 的作用。它可以帮助您用此信息标注模型变量。

注意:面向 Flax Linen 用户flax.nnx.spmd API 与 Linen Flax 上的 (p)jit 指南 中的模型定义级别描述的类似。但是,由于 Flax NNX 带来的优势,Flax NNX 中的顶层代码更简单,一些文本解释也将更加更新和清晰。

如果您不熟悉 JAX 中的并行化,可以从以下教程中了解有关用于扩展的 JAX API 的更多信息

设置#

导入一些必要的依赖项。

注意:本指南使用 --xla_force_host_platform_device_count=8 标志在 Google Colab/Jupyter Notebook 中的 CPU 环境中模拟多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import *

import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from flax import nnx

import optax # Optax for common losses and optimizers.
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

以下代码展示了如何按照 JAX 的 分布式数组和自动并行化 指南导入和设置 JAX 级别设备 API

  1. 使用 JAX jax.sharding.Mesh 启动 2x4 设备 mesh(8 个设备)。此布局与 TPU v3-8(也是 8 个设备)相同。

  2. 使用 axis_names 参数为每个轴标注名称。标注轴名称的典型方法是 axis_name=('data', 'model'),其中

  • 'data':用于对输入和激活的批次维度进行数据并行分片的网格维度。

  • 'model':用于跨设备分片模型参数的网格维度。

# Create a mesh of two dimensions and annotate each axis with a name.
mesh = Mesh(devices=np.array(jax.devices()).reshape(2, 4),
            axis_names=('data', 'model'))
print(mesh)
Mesh('data': 2, 'model': 4)

定义具有指定分片的模型#

接下来,创建一个名为 DotReluDot 的示例层,它继承了 Flax nnx.Module。此层对输入 x 进行两次点积运算,并在两者之间使用 jax.nn.relu(ReLU)激活函数。

要使用其理想分片标注模型变量,您可以使用 flax.nnx.with_partitioning 包装其初始化函数。本质上,这调用 flax.nnx.with_metadata,它向相应的 nnx.Variable 添加一个 .sharding 属性字段。

注意:此标注将在 Flax NNX 中的提升变换中被保留并相应地调整。这意味着如果您将分片标注与任何修改轴的变换(如 nnx.vmapnnx.scan)一起使用,则需要通过 transform_metadata 参数提供该额外轴的分片。请查看 Flax NNX 变换(transforms)指南 以了解更多信息。

class DotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1` and annotate its kernel with.
    # `sharding (None, 'model')`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
      use_bias=False,  # or use `bias_init` to give it annotation too
      rngs=rngs)

    # Initialize a weight param `w2` and annotate with sharding ('model', None).
    # Note that this is simply adding `.sharding` to the variable as metadata!
    self.w2 = nnx.Param(
      init_fn(rngs.params(), (depth, depth)),  # RNG key and shape for W2 creation
      sharding=('model', None),
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # In data parallelism, input / intermediate value's first dimension (batch)
    # will be sharded on `data` axis
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
    z = jnp.dot(y, self.w2.value)
    return z

了解分片名称#

所谓的“分片标注”本质上是设备轴名称元组,如 'data''model'None。这描述了此 JAX 数组的每个维度应如何分片 - 跨设备网格维度之一分片,或者根本不分片。

因此,当您定义形状为 (depth, depth) 且标注为 (None, 'model')W1

  • 第一个维度将在所有设备上复制。

  • 第二个维度将在设备网格的 'model' 轴上分片。这意味着 W1 将在设备 (0, 4)(1, 5)(2, 6)(3, 7) 上进行 4 路分片,在此维度上。

JAX 的 分布式数组和自动并行化 指南提供了更多示例和解释。

初始化分片模型#

现在,您已经在 nnx.Variable 上附加了注释,但实际的权重尚未进行分片。如果您直接创建此模型,所有 JAX 数组仍然停留在设备 0 上。在实际应用中,您可能希望避免这种情况,因为大型模型在这种情况下会“OOM”(导致设备内存不足),而其他设备未被利用。

unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))

# You have annotations sticked there, yay!
print(unsharded_model.dot1.kernel.sharding)     # (None, 'model')
print(unsharded_model.w2.sharding)              # ('model', None)

# But the actual arrays are not sharded?
print(unsharded_model.dot1.kernel.value.sharding)  # SingleDeviceSharding
print(unsharded_model.w2.value.sharding)           # SingleDeviceSharding
(None, 'model')
('model', None)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

这里,您应该通过 nnx.jit 利用 JAX 的编译机制来创建分片模型。关键是在一个 jit 函数中初始化模型,并在模型状态上分配分片。

  1. 使用 nnx.get_partition_spec 来剥离模型变量上附加的 .sharding 注释。

  2. 调用 jax.lax.with_sharding_constraint 将模型状态与分片注释绑定。此 API 告诉顶层 jit 如何对变量进行分片!

  3. 丢弃未分片的状态,并根据分片状态返回模型。

  4. 使用 nnx.jit 编译整个函数,这将允许输出成为一个有状态的 NNX 模块。

  5. 在设备网格上下文下运行它,以便 JAX 知道将它分片到哪些设备上。

整个编译后的 create_sharded_model() 函数将直接生成一个具有分片 JAX 数组的模型,并且不会发生任何单设备“OOM”!

@nnx.jit
def create_sharded_model():
  model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.
  state = nnx.state(model)                   # The model's state, a pure pytree.
  pspecs = nnx.get_partition_spec(state)     # Strip out the annotations from state.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)           # The model is sharded now!
  return model

with mesh:
  sharded_model = create_sharded_model()

# They are some `GSPMDSharding` now - not a single device!
print(sharded_model.dot1.kernel.value.sharding)
print(sharded_model.w2.value.sharding)

# Check out their equivalency with some easier-to-read sharding descriptions
assert sharded_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=unpinned_host)

您可以使用 jax.debug.visualize_array_sharding 查看任何一维或二维数组的分片。

print("sharded_model.dot1.kernel (None, 'model') :")
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)
print("sharded_model.w2 ('model', None) :")
jax.debug.visualize_array_sharding(sharded_model.w2.value)
sharded_model.dot1.kernel (None, 'model') :
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
sharded_model.w2 ('model', None) :
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

关于 jax.lax.with_sharding_constraint(半自动并行化)#

对 JAX 数组进行分片的关键是在 jax.jit 函数中调用 jax.lax.with_sharding_constraint。请注意,如果不在 JAX 设备网格上下文下,它将抛出错误。

注意: JAX 文档中的 并行编程简介分布式数组和自动并行化 详细介绍了使用 jax.jit 进行自动并行化,以及使用 jax.jit`jax.lax.with_sharding_constraint 进行半自动并行化。

您可能已经注意到,您还在模型定义中使用了一次 jax.lax.with_sharding_constraint,来约束中间值的分配。这仅仅是为了说明,如果您想显式地对不是模型变量的值进行分片,您可以始终与 Flax NNX API 正交地使用它。

这就带来了一个问题:为什么还要使用 Flax NNX 注释 API?为什么不在模型定义中添加 JAX 分片约束?最主要的原因是,您仍然需要显式注释才能从磁盘上的检查点加载分片模型。这将在下一节中介绍。

从检查点加载分片模型#

现在您可以初始化一个分片模型,而不会出现 OOM,但如何从磁盘上的检查点加载它呢?JAX 检查点库(如 Orbax)通常支持加载分片模型,前提是给出了分片树。

您可以使用 nnx.get_named_sharding 生成这样的分片树。为了避免任何实际的内存分配,使用 nnx.eval_shape 变换来生成一个抽象 JAX 数组的模型,并且只使用它的 .sharding 注释来获取分片树。

下面是一个使用 Orbax 的 StandardCheckpointer API 的示例演示。查看 Orbax 网站 以了解他们的最新更新和推荐的 API。

import orbax.checkpoint as ocp

# Save the sharded state.
sharded_state = nnx.state(sharded_model)
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(path / 'checkpoint_name', sharded_state)

# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model)
# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`
# that contains both sharding and the shape/dtype of the arrays.
abs_state = jax.tree.map(
  lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
  abs_state, nnx.get_named_sharding(abs_state, mesh)
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
                                      args=ocp.args.StandardRestore(abs_state))
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

编译训练循环#

现在,从初始化或从检查点,您都拥有了一个分片模型。为了执行编译后的、扩展的训练,您还需要对输入进行分片。在这个数据并行性示例中,训练数据在其批次维度上跨 data 设备轴进行分片,因此您应该将数据放入分片 ('data', None) 中。您可以使用 jax.device_put 来实现这一点。

请注意,如果所有输入都具有正确的分片,即使没有 JIT 编译,输出也会以最自然的方式进行分片。在下面的示例中,即使没有在输出 y 上使用 jax.lax.with_sharding_constraint,它仍然被分片为 ('data', None)

如果您想知道原因:DotReluDot.__call__ 的第二个矩阵乘法有两个输入,分片分别为 ('data', 'model')('model', None),其中两个输入的收缩轴都是 model。因此,发生了一个 reduce-scatter 矩阵乘法,它会自然地将输出分片为 ('data', None)。如果您想在低级别学习它是如何发生的,请查看 JAX shard map 集合指南 及其示例。

# In data parallelism, the first dimension (batch) will be sharded on `data` axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)

with mesh:
  output = sharded_model(input)
print(output.shape)
jax.debug.visualize_array_sharding(output)  # Also sharded as ('data', None)
(8, 1024)
                                                                                
                                                                                
                                  CPU 0,1,2,3                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                  CPU 4,5,6,7                                   
                                                                                
                                                                                
                                                                                

现在,训练循环的其余部分非常传统 - 几乎与 NNX 基础 中的示例相同,只是输入和标签也进行了显式分片。

nnx.jit 会根据输入的分片方式自动调整并选择最佳布局,因此请尝试为自己的模型和输入使用不同的分片方式。

optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))  # reference sharing

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model: DotReluDot):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)

  return loss

input = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)

with mesh:
  for i in range(5):
    loss = train_step(sharded_model, optimizer, input, label)
    print(loss)    # Model (over-)fitting to the labels quickly.
1.455235
0.7646729
0.50971293
0.378493
0.28089797

性能分析#

如果您在 TPU pod 或 pod 切片上运行,您可以创建一个自定义的 block_all() 实用程序函数,如下所示,来衡量性能。

%%timeit

def block_all(xs):
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(sharded_model, optimizer, input, label))
7.09 ms ± 390 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

逻辑轴注释#

JAX 的 自动 SPMD 鼓励用户探索不同的分片布局,以找到最佳布局。为此,在 Flax 中,您可以选择使用更具描述性的轴名称进行注释(不仅仅是设备网格轴名称,例如 'data''model'),只要您提供从别名到设备网格轴的映射。

您可以将映射与注释一起作为对应 nnx.Variable 的另一个元数据提供,或者在顶层覆盖它。查看下面的 LogicalDotReluDot 示例。

# The mapping from alias annotation to the device mesh.
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))

class LogicalDotReluDot(nnx.Module):
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()

    # Initialize a sublayer `self.dot1`.
    self.dot1 = nnx.Linear(
      depth, depth,
      kernel_init=nnx.with_metadata(
        # Provide the sharding rules here.
        init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules),
      use_bias=False,
      rngs=rngs)

    # Initialize a weight param `w2`.
    self.w2 = nnx.Param(
      # Didn't provide the sharding rules here to show you how to overwrite it later.
      nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))(
        rngs.params(), (depth, depth))
    )

  def __call__(self, x: jax.Array):
    y = self.dot1(x)
    y = jax.nn.relu(y)
    # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))
    z = jnp.dot(y, self.w2.value)
    return z

如果您没有在模型定义中提供所有 sharding_rule 注释,您可以在调用 nnx.Statennx.get_partition_specnnx.get_named_sharding 之前,编写几行代码将其添加到模型的 nnx.State 中。

def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState:
  vs.sharding_rules = sharding_rules
  return vs

@nnx.jit
def create_sharded_logical_model():
  model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))
  state = nnx.state(model)
  state = jax.tree.map(add_sharding_rule, state,
                       is_leaf=lambda x: isinstance(x, nnx.VariableState))
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
  return model

with mesh:
  sharded_logical_model = create_sharded_logical_model()

jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)
jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)

# Check out their equivalency with some easier-to-read sharding descriptions.
assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_logical_model.w2.value.sharding.is_equivalent_to(
  NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)

with mesh:
  logical_output = sharded_logical_model(input)
  assert logical_output.sharding.is_equivalent_to(
    NamedSharding(mesh, PartitionSpec('data', None)), ndim=2
  )
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

何时使用设备轴/逻辑轴#

选择何时使用设备轴或逻辑轴取决于您想要控制模型分片的程度。

  • 设备网格轴:

    • 对于更简单的模型,这可以为您节省几行代码,这些代码用于将逻辑命名转换回设备命名。

    • 中间激活值的分配只能通过 jax.lax.with_sharding_constraint 和设备网格轴来完成。因此,如果您想要对模型的分配进行非常精细的控制,那么在所有地方直接使用设备网格轴名称可能会不太混乱。

  • 逻辑命名:如果您想进行试验并为模型权重找到最优的分配布局,这将非常有用。