在多个设备上扩展#
本指南演示如何使用 Flax NNX Module
在多个设备和主机(如 GPU、Google TPU 和 CPU)上扩展,例如使用 JAX 实时编译机制 (jax.jit
).
概述#
Flax 依赖 JAX 进行数值计算,并跨多个设备(如 GPU 和 TPU)扩展计算。扩展的核心是 JAX 实时编译器 jax.jit
。在本指南中,您将使用 Flax 自带的 nnx.jit
,它包装了 jax.jit
并更方便地与 Flax NNX Module
协同工作。
JAX 编译遵循 单程序多数据 (SPMD) 范式。这意味着您像在单个设备上运行一样编写 Python 代码,而 jax.jit
会自动编译并在多个设备上运行它。
为了确保编译性能,您通常需要指示 JAX 如何跨设备分片模型的变量。这就是 Flax NNX 的分片元数据 API - flax.nnx.spmd
- 的作用。它可以帮助您用此信息标注模型变量。
注意:面向 Flax Linen 用户:
flax.nnx.spmd
API 与 Linen Flax 上的(p)jit
指南 中的模型定义级别描述的类似。但是,由于 Flax NNX 带来的优势,Flax NNX 中的顶层代码更简单,一些文本解释也将更加更新和清晰。
如果您不熟悉 JAX 中的并行化,可以从以下教程中了解有关用于扩展的 JAX API 的更多信息
并行编程简介:一个 101 级教程,介绍使用
jax.jit
进行自动并行化、使用jax.jit
和jax.lax.with_sharding_constraint
进行半自动并行化以及使用shard_map
进行手动分片的基本知识。分布式数组和自动并行化:一个关于使用
jax.jit
和jax.lax.with_sharding_constraint
进行并行化的更详细教程。在学习 101 之后再学习此内容。使用
shard_map
进行手动并行:另一个更深入的文档,遵循 101。
设置#
导入一些必要的依赖项。
注意:本指南使用 --xla_force_host_platform_device_count=8
标志在 Google Colab/Jupyter Notebook 中的 CPU 环境中模拟多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import *
import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import nnx
import optax # Optax for common losses and optimizers.
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
以下代码展示了如何按照 JAX 的 分布式数组和自动并行化 指南导入和设置 JAX 级别设备 API
使用 JAX
jax.sharding.Mesh
启动 2x4 设备mesh
(8 个设备)。此布局与 TPU v3-8(也是 8 个设备)相同。使用
axis_names
参数为每个轴标注名称。标注轴名称的典型方法是axis_name=('data', 'model')
,其中
'data'
:用于对输入和激活的批次维度进行数据并行分片的网格维度。'model'
:用于跨设备分片模型参数的网格维度。
# Create a mesh of two dimensions and annotate each axis with a name.
mesh = Mesh(devices=np.array(jax.devices()).reshape(2, 4),
axis_names=('data', 'model'))
print(mesh)
Mesh('data': 2, 'model': 4)
定义具有指定分片的模型#
接下来,创建一个名为 DotReluDot
的示例层,它继承了 Flax nnx.Module
。此层对输入 x
进行两次点积运算,并在两者之间使用 jax.nn.relu
(ReLU)激活函数。
要使用其理想分片标注模型变量,您可以使用 flax.nnx.with_partitioning
包装其初始化函数。本质上,这调用 flax.nnx.with_metadata
,它向相应的 nnx.Variable
添加一个 .sharding
属性字段。
注意:此标注将在 Flax NNX 中的提升变换中被保留并相应地调整。这意味着如果您将分片标注与任何修改轴的变换(如
nnx.vmap
、nnx.scan
)一起使用,则需要通过transform_metadata
参数提供该额外轴的分片。请查看 Flax NNX 变换(transforms)指南 以了解更多信息。
class DotReluDot(nnx.Module):
def __init__(self, depth: int, rngs: nnx.Rngs):
init_fn = nnx.initializers.lecun_normal()
# Initialize a sublayer `self.dot1` and annotate its kernel with.
# `sharding (None, 'model')`.
self.dot1 = nnx.Linear(
depth, depth,
kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
use_bias=False, # or use `bias_init` to give it annotation too
rngs=rngs)
# Initialize a weight param `w2` and annotate with sharding ('model', None).
# Note that this is simply adding `.sharding` to the variable as metadata!
self.w2 = nnx.Param(
init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
sharding=('model', None),
)
def __call__(self, x: jax.Array):
y = self.dot1(x)
y = jax.nn.relu(y)
# In data parallelism, input / intermediate value's first dimension (batch)
# will be sharded on `data` axis
y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
z = jnp.dot(y, self.w2.value)
return z
了解分片名称#
所谓的“分片标注”本质上是设备轴名称元组,如 'data'
、'model'
或 None
。这描述了此 JAX 数组的每个维度应如何分片 - 跨设备网格维度之一分片,或者根本不分片。
因此,当您定义形状为 (depth, depth)
且标注为 (None, 'model')
的 W1
时
第一个维度将在所有设备上复制。
第二个维度将在设备网格的
'model'
轴上分片。这意味着W1
将在设备(0, 4)
、(1, 5)
、(2, 6)
和(3, 7)
上进行 4 路分片,在此维度上。
JAX 的 分布式数组和自动并行化 指南提供了更多示例和解释。
初始化分片模型#
现在,您已经在 nnx.Variable
上附加了注释,但实际的权重尚未进行分片。如果您直接创建此模型,所有 JAX 数组仍然停留在设备 0
上。在实际应用中,您可能希望避免这种情况,因为大型模型在这种情况下会“OOM”(导致设备内存不足),而其他设备未被利用。
unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))
# You have annotations sticked there, yay!
print(unsharded_model.dot1.kernel.sharding) # (None, 'model')
print(unsharded_model.w2.sharding) # ('model', None)
# But the actual arrays are not sharded?
print(unsharded_model.dot1.kernel.value.sharding) # SingleDeviceSharding
print(unsharded_model.w2.value.sharding) # SingleDeviceSharding
(None, 'model')
('model', None)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
这里,您应该通过 nnx.jit
利用 JAX 的编译机制来创建分片模型。关键是在一个 jit
函数中初始化模型,并在模型状态上分配分片。
使用
nnx.get_partition_spec
来剥离模型变量上附加的.sharding
注释。调用
jax.lax.with_sharding_constraint
将模型状态与分片注释绑定。此 API 告诉顶层jit
如何对变量进行分片!丢弃未分片的状态,并根据分片状态返回模型。
使用
nnx.jit
编译整个函数,这将允许输出成为一个有状态的 NNX 模块。在设备网格上下文下运行它,以便 JAX 知道将它分片到哪些设备上。
整个编译后的 create_sharded_model()
函数将直接生成一个具有分片 JAX 数组的模型,并且不会发生任何单设备“OOM”!
@nnx.jit
def create_sharded_model():
model = DotReluDot(1024, rngs=nnx.Rngs(0)) # Unsharded at this moment.
state = nnx.state(model) # The model's state, a pure pytree.
pspecs = nnx.get_partition_spec(state) # Strip out the annotations from state.
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state) # The model is sharded now!
return model
with mesh:
sharded_model = create_sharded_model()
# They are some `GSPMDSharding` now - not a single device!
print(sharded_model.dot1.kernel.value.sharding)
print(sharded_model.w2.value.sharding)
# Check out their equivalency with some easier-to-read sharding descriptions
assert sharded_model.dot1.kernel.value.sharding.is_equivalent_to(
NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_model.w2.value.sharding.is_equivalent_to(
NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=unpinned_host)
您可以使用 jax.debug.visualize_array_sharding
查看任何一维或二维数组的分片。
print("sharded_model.dot1.kernel (None, 'model') :")
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)
print("sharded_model.w2 ('model', None) :")
jax.debug.visualize_array_sharding(sharded_model.w2.value)
sharded_model.dot1.kernel (None, 'model') :
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
sharded_model.w2 ('model', None) :
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
关于 jax.lax.with_sharding_constraint
(半自动并行化)#
对 JAX 数组进行分片的关键是在 jax.jit
函数中调用 jax.lax.with_sharding_constraint
。请注意,如果不在 JAX 设备网格上下文下,它将抛出错误。
注意: JAX 文档中的 并行编程简介 和 分布式数组和自动并行化 详细介绍了使用
jax.jit
进行自动并行化,以及使用jax.jit
和 `jax.lax.with_sharding_constraint 进行半自动并行化。
您可能已经注意到,您还在模型定义中使用了一次 jax.lax.with_sharding_constraint
,来约束中间值的分配。这仅仅是为了说明,如果您想显式地对不是模型变量的值进行分片,您可以始终与 Flax NNX API 正交地使用它。
这就带来了一个问题:为什么还要使用 Flax NNX 注释 API?为什么不在模型定义中添加 JAX 分片约束?最主要的原因是,您仍然需要显式注释才能从磁盘上的检查点加载分片模型。这将在下一节中介绍。
从检查点加载分片模型#
现在您可以初始化一个分片模型,而不会出现 OOM,但如何从磁盘上的检查点加载它呢?JAX 检查点库(如 Orbax)通常支持加载分片模型,前提是给出了分片树。
您可以使用 nnx.get_named_sharding
生成这样的分片树。为了避免任何实际的内存分配,使用 nnx.eval_shape
变换来生成一个抽象 JAX 数组的模型,并且只使用它的 .sharding
注释来获取分片树。
下面是一个使用 Orbax 的 StandardCheckpointer
API 的示例演示。查看 Orbax 网站 以了解他们的最新更新和推荐的 API。
import orbax.checkpoint as ocp
# Save the sharded state.
sharded_state = nnx.state(sharded_model)
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(path / 'checkpoint_name', sharded_state)
# Load a sharded state from checkpoint, without `sharded_model` or `sharded_state`.
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_state = nnx.state(abs_model)
# Orbax API expects a tree of abstract `jax.ShapeDtypeStruct`
# that contains both sharding and the shape/dtype of the arrays.
abs_state = jax.tree.map(
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_state, nnx.get_named_sharding(abs_state, mesh)
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
args=ocp.args.StandardRestore(abs_state))
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
编译训练循环#
现在,从初始化或从检查点,您都拥有了一个分片模型。为了执行编译后的、扩展的训练,您还需要对输入进行分片。在这个数据并行性示例中,训练数据在其批次维度上跨 data
设备轴进行分片,因此您应该将数据放入分片 ('data', None)
中。您可以使用 jax.device_put
来实现这一点。
请注意,如果所有输入都具有正确的分片,即使没有 JIT 编译,输出也会以最自然的方式进行分片。在下面的示例中,即使没有在输出 y
上使用 jax.lax.with_sharding_constraint
,它仍然被分片为 ('data', None)
。
如果您想知道原因:
DotReluDot.__call__
的第二个矩阵乘法有两个输入,分片分别为('data', 'model')
和('model', None)
,其中两个输入的收缩轴都是model
。因此,发生了一个 reduce-scatter 矩阵乘法,它会自然地将输出分片为('data', None)
。如果您想在低级别学习它是如何发生的,请查看 JAX shard map 集合指南 及其示例。
# In data parallelism, the first dimension (batch) will be sharded on `data` axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)
with mesh:
output = sharded_model(input)
print(output.shape)
jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None)
(8, 1024)
CPU 0,1,2,3 CPU 4,5,6,7
现在,训练循环的其余部分非常传统 - 几乎与 NNX 基础 中的示例相同,只是输入和标签也进行了显式分片。
nnx.jit
会根据输入的分片方式自动调整并选择最佳布局,因此请尝试为自己的模型和输入使用不同的分片方式。
optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3)) # reference sharing
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model: DotReluDot):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return loss
input = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
label = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)
with mesh:
for i in range(5):
loss = train_step(sharded_model, optimizer, input, label)
print(loss) # Model (over-)fitting to the labels quickly.
1.455235
0.7646729
0.50971293
0.378493
0.28089797
性能分析#
如果您在 TPU pod 或 pod 切片上运行,您可以创建一个自定义的 block_all()
实用程序函数,如下所示,来衡量性能。
%%timeit
def block_all(xs):
jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
return xs
with mesh:
new_state = block_all(train_step(sharded_model, optimizer, input, label))
7.09 ms ± 390 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
逻辑轴注释#
JAX 的 自动 SPMD 鼓励用户探索不同的分片布局,以找到最佳布局。为此,在 Flax 中,您可以选择使用更具描述性的轴名称进行注释(不仅仅是设备网格轴名称,例如 'data'
和 'model'
),只要您提供从别名到设备网格轴的映射。
您可以将映射与注释一起作为对应 nnx.Variable
的另一个元数据提供,或者在顶层覆盖它。查看下面的 LogicalDotReluDot
示例。
# The mapping from alias annotation to the device mesh.
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))
class LogicalDotReluDot(nnx.Module):
def __init__(self, depth: int, rngs: nnx.Rngs):
init_fn = nnx.initializers.lecun_normal()
# Initialize a sublayer `self.dot1`.
self.dot1 = nnx.Linear(
depth, depth,
kernel_init=nnx.with_metadata(
# Provide the sharding rules here.
init_fn, sharding=('embed', 'hidden'), sharding_rules=sharding_rules),
use_bias=False,
rngs=rngs)
# Initialize a weight param `w2`.
self.w2 = nnx.Param(
# Didn't provide the sharding rules here to show you how to overwrite it later.
nnx.with_metadata(init_fn, sharding=('hidden', 'embed'))(
rngs.params(), (depth, depth))
)
def __call__(self, x: jax.Array):
y = self.dot1(x)
y = jax.nn.relu(y)
# Unfortunately the logical aliasing doesn't work on lower-level JAX calls.
y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', None))
z = jnp.dot(y, self.w2.value)
return z
如果您没有在模型定义中提供所有 sharding_rule
注释,您可以在调用 nnx.State
或 nnx.get_partition_spec
或 nnx.get_named_sharding
之前,编写几行代码将其添加到模型的 nnx.State
中。
def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState:
vs.sharding_rules = sharding_rules
return vs
@nnx.jit
def create_sharded_logical_model():
model = LogicalDotReluDot(1024, rngs=nnx.Rngs(0))
state = nnx.state(model)
state = jax.tree.map(add_sharding_rule, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
pspecs = nnx.get_partition_spec(state)
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state)
return model
with mesh:
sharded_logical_model = create_sharded_logical_model()
jax.debug.visualize_array_sharding(sharded_logical_model.dot1.kernel.value)
jax.debug.visualize_array_sharding(sharded_logical_model.w2.value)
# Check out their equivalency with some easier-to-read sharding descriptions.
assert sharded_logical_model.dot1.kernel.value.sharding.is_equivalent_to(
NamedSharding(mesh, PartitionSpec(None, 'model')), ndim=2
)
assert sharded_logical_model.w2.value.sharding.is_equivalent_to(
NamedSharding(mesh, PartitionSpec('model', None)), ndim=2
)
with mesh:
logical_output = sharded_logical_model(input)
assert logical_output.sharding.is_equivalent_to(
NamedSharding(mesh, PartitionSpec('data', None)), ndim=2
)
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
何时使用设备轴/逻辑轴#
选择何时使用设备轴或逻辑轴取决于您想要控制模型分片的程度。
设备网格轴:
对于更简单的模型,这可以为您节省几行代码,这些代码用于将逻辑命名转换回设备命名。
中间激活值的分配只能通过
jax.lax.with_sharding_constraint
和设备网格轴来完成。因此,如果您想要对模型的分配进行非常精细的控制,那么在所有地方直接使用设备网格轴名称可能会不太混乱。
逻辑命名:如果您想进行试验并为模型权重找到最优的分配布局,这将非常有用。