将 Flax NNX 和 Linen 一起使用#
本指南适用于想要将代码库混合使用 Flax Linen 和 Flax NNX Module
的现有 Flax 用户,这得益于 flax.nnx.bridge
API。
如果您希望:
逐步将您的代码库迁移到 NNX,一次迁移一个模块;
拥有已迁移到 NNX 但您尚未迁移的外部依赖项,或者您已迁移到 NNX 但外部依赖项仍在 Linen 中。
我们希望这能让您以自己的速度进行迁移和尝试 NNX,并利用两者的优势。我们还将讨论如何解决在几个方面,它们在本质上是不同的,这两个 API 交互操作的注意事项。
注意:
本指南介绍如何将 Linen 和 NNX 模块粘合在一起。要将现有的 Linen 模块迁移到 NNX,请查看 从 Flax Linen 迁移到 Flax NNX 指南。
所有内置的 Linen 层都应该有等效的 NNX 版本!查看 内置 NNX 层 的列表。
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from flax import nnx
from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from typing import *
子模块就是你所需要的#
Flax 模型始终是模块树——要么是旧的 Linen 模块(flax.linen.Module
,通常写成 nn.Module
),要么是 NNX 模块(nnx.Module
)。
一个 nnx.bridge
包装器将两种类型粘合在一起,两种方式都可以:
nnx.bridge.ToNNX
:将 Linen 模块转换为 NNX,以便它可以作为另一个 NNX 模块的子模块,或者独立存在以在 NNX 样式的训练循环中进行训练。nnx.bridge.ToLinen
:反之亦然,将 NNX 模块转换为 Linen。
这意味着您可以采用自上而下或自下而上的行为:将整个 Linen 模块转换为 NNX,然后逐步向下移动,或者将所有较低级别的模块转换为 NNX,然后向上移动。
基础#
Linen 和 NNX 模块之间存在两个基本区别:
无状态与有状态:Linen 模块实例是无状态的:变量从纯函数
.init()
调用中返回,并单独管理。然而,NNX 模块将自己的变量作为实例属性拥有。延迟与急切:Linen 模块仅在实际看到输入时才分配空间来创建变量。而 NNX 模块实例在实例化时就创建其变量,而不会看到示例输入。
考虑到这一点,让我们看看 nnx.bridge
包装器是如何解决这些差异的。
Linen -> NNX#
由于 Linen 模块可能需要输入来创建变量,因此我们在从 Linen 转换而来的 NNX 模块中半正式地支持延迟初始化。当您提供示例输入时,Linen 变量就会创建。
对您来说,就是在您在 Linen 代码中调用 module.init()
的地方调用 nnx.bridge.lazy_init()
。
(注意:您可以对任何 NNX 模块调用 nnx.display
来检查其所有变量和状态。)
class LinenDot(nn.Module):
out_dim: int
w_init: Callable[..., Any] = nn.initializers.lecun_normal()
@nn.compact
def __call__(self, x):
# Linen might need the input shape to create the weight!
w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))
return x @ w
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64),
rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen
bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen
y = model(x) # => `y = model.apply(var, x)` in Linen
nnx.display(model)
# In-place swap your weight array and the model still works!
model.w.value = jax.random.normal(jax.random.key(1), (32, 64))
assert not jnp.allclose(y, model(x))
nnx.bridge.lazy_init
即使顶层模块是纯 NNX 模块也能正常工作,因此您可以按您希望的方式进行子模块化。
class NNXOuter(nnx.Module):
def __init__(self, out_dim: int, rngs: nnx.Rngs):
self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)
self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))
def __call__(self, x):
return self.dot(x) + self.b
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line
nnx.display(model)
Linen 权重已经转换为典型的 NNX 变量,它是实际 JAX 数组值的薄包装器。这里,w
是一个 nnx.Param
,因为它属于 LinenDot
模块的 params
集合。
我们将在 NNX 变量 <-> Linen 集合 部分详细介绍不同的集合和类型。现在,您只需知道它们被转换为与本机变量类似的 NNX 变量。
assert isinstance(model.dot.w, nnx.Param)
assert isinstance(model.dot.w.value, jax.Array)
如果您在不使用 nnx.bridge.lazy_init
的情况下创建此模型,则在外部定义的 NNX 变量将按常例初始化,但 Linen 部分(包装在 ToNNX
中)不会初始化。
partial_model = NNXOuter(64, rngs=nnx.Rngs(0))
nnx.display(partial_model)
full_model = bridge.lazy_init(partial_model, x)
nnx.display(full_model)
NNX -> Linen#
要将 NNX 模块转换为 Linen,您应该将您的创建参数转发到 bridge.ToLinen
并让它处理实际的创建过程。
这是因为 NNX 模块实例在创建时急切地初始化所有变量,这会消耗内存和计算资源。另一方面,Linen 模块是无状态的,典型的 init
和 apply
过程涉及多次创建它们。因此 bridge.to_linen
将处理实际的模块创建并确保不会两次分配内存。
class NNXDot(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(
rngs.params(), (in_dim, out_dim)))
def __call__(self, x: jax.Array):
return x @ self.w
x = jax.random.normal(jax.random.key(42), (4, 32))
# Pass in the arguments, not an actual module
model = bridge.to_linen(NNXDot, 32, out_dim=64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)
print(list(variables.keys()))
print(variables['params']['w'].shape) # => (32, 64)
print(y.shape) # => (4, 64)
['nnx', 'params']
(32, 64)
(4, 64)
请注意,ToLinen
模块需要跟踪一个额外的变量集合——nnx
——用于底层 NNX 模块的静态元数据。
# This new field stores the static data that defines the underlying `NNXDot`
print(type(variables['nnx']['graphdef'])) # => `nnx.graph.NodeDef`
<class 'flax.nnx.graph.NodeDef'>
bridge.to_linen
实际上是 Linen 模块 bridge.ToLinen
的便捷包装器。您很可能根本不需要直接使用 ToLinen
,除非您使用 ToLinen
的内置参数之一。例如,如果您的 NNX 模块不想使用 RNG 处理进行初始化:
class NNXAddConstant(nnx.Module):
def __init__(self):
self.constant = nnx.Variable(jnp.array(1))
def __call__(self, x):
return x + self.constant
# You have to use `skip_rng=True` because this module's `__init__` don't
# take `rng` as argument
model = bridge.ToLinen(NNXAddConstant, skip_rng=True)
y, var = model.init_with_output(jax.random.key(0), x)
与 ToNNX
类似,您可以使用 ToLinen
来创建一个另一个 Linen 模块的子模块。
class LinenOuter(nn.Module):
out_dim: int
@nn.compact
def __call__(self, x):
dot = bridge.to_linen(NNXDot, x.shape[-1], self.out_dim)
b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim))
return dot(x) + b
x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenOuter(out_dim=64)
y, variables = model.init_with_output(jax.random.key(0), x)
w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']
print(w.shape, b.shape, y.shape)
(32, 64) (1, 64) (4, 64)
处理 RNG 密钥#
所有 Flax 模块,无论是 Linen 还是 NNX,都会自动处理用于变量创建和随机层(如 dropout)的 RNG 密钥。但是,RNG 密钥拆分的具体逻辑是不同的,因此您无法在 Linen 和 NNX 模块之间生成相同的参数,即使您传递相同的密钥。
另一个区别是 NNX 模块是有状态的,因此它们可以在自身内部跟踪和更新 RNG 密钥。
Linen 到 NNX#
如果您将 Linen 模块转换为 NNX,您将享受有状态的好处,并且不需要在每次模块调用时传递额外的 RNG 密钥。您可以始终使用 nnx.reseed
来重置内部的 RNG 状态。
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))
# We don't really need to call lazy_init because no extra params were created here,
# but it's a good practice to always add this line.
bridge.lazy_init(model, x)
y1, y2 = model(x), model(x)
assert not jnp.allclose(y1, y2) # Two runs yield different outputs!
# Reset the dropout RNG seed, so that next model run will be the same as the first.
nnx.reseed(model, dropout=0)
assert jnp.allclose(y1, model(x))
NNX 到 Linen#
如果您将 NNX 模块转换为 Linen,底层 NNX 模块的 RNG 状态仍然会是顶层 variables
的一部分。另一方面,Linen apply()
调用在每次调用时都接受不同的 RNG 密钥,这会重置内部 Linen 环境并允许生成不同的随机数据。
现在,这实际上取决于您的底层 NNX 模块是使用 RNG 状态还是使用传入的参数生成新的随机数据。幸运的是,nnx.Dropout
支持这两种方式——如果有传入密钥,则使用传入密钥,否则使用它自己的 RNG 状态。
这为您提供了两种处理 RNG 密钥的样式选项:
NNX 样式(推荐):让底层 NNX 状态管理 RNG 密钥,在
apply()
中不需要传入额外的密钥。这意味着需要额外几行代码来更改每次调用时的variables
,但一旦您的整个模型不再需要ToLinen
,事情就会变得更容易。Linen 样式:在每次
apply()
调用时,只需传入不同的 RNG 密钥。
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.to_linen(nnx.Dropout, rate=0.5)
variables = model.init({'dropout': jax.random.key(0)}, x)
# The NNX RNG state was stored inside `variables`
print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)
print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)
# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)
assert not jnp.allclose(y1, y2) # Every call yields different output!
# Linen style: Just pass different RNG keys for every `apply()` call.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
assert not jnp.allclose(y3, y4) # Every call yields different output!
y5 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
assert jnp.allclose(y3, y5) # When you use same top-level RNG, outputs are same
The RNG key in state: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Number of key splits: 0
Number of key splits after y2: 2
NNX 变量类型与 Linen 集合#
当您想将一些变量作为一类进行分组时,在 Linen 中,您使用不同的集合;在 NNX 中,由于所有变量都应该是顶层 Python 属性,因此您使用不同的变量类型。
因此,在混合使用 Linen 和 NNX 模块时,Flax 必须知道 Linen 集合和 NNX 变量类型之间的 1 对 1 映射,以便 ToNNX
和 ToLinen
可以自动执行转换。
Flax 为此维护了一个注册表,并且它已经涵盖了所有 Flax 的内置 Linen 集合。您可以使用 nnx.register_variable_name_type_pair
注册 NNX 变量类型和 Linen 集合名称的额外映射。
Linen 到 NNX#
对于 Linen 模块的任何集合,ToNNX
会将其所有端点数组(即叶子)转换为 nnx.Variable
的子类型,可以从注册表获取,也可以动态创建。
(但是,我们仍然将整个集合保留为一个类属性,因为 Linen 模块可能在不同的集合中拥有重复的名称。)
class LinenMultiCollections(nn.Module):
out_dim: int
def setup(self):
self.w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))
self.b = self.param('b', nn.zeros_init(), (self.out_dim,))
self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
def __call__(self, x):
if not self.is_initializing():
self.count.value += 1
y = x @ self.w + self.b
self.sow('intermediates', 'dot_sum', jnp.sum(y))
return y
x = jax.random.normal(jax.random.key(42), (2, 4))
model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)
print(model.w) # Of type `nnx.Param` - note this is still under attribute `params`
print(model.b) # Of type `nnx.Param`
print(model.count) # Of type `counter` - auto-created type from the collection name
print(type(model.count))
y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger
print(model.dot_sum) # Of type `nnx.Intermediates`
Param(
value=Array([[ 0.35401407, 0.38010964, -0.20674096],
[-0.7356256 , 0.35613298, -0.5099556 ],
[-0.4783049 , 0.4310735 , 0.30137998],
[-0.6102254 , -0.2668519 , -1.053598 ]], dtype=float32)
)
Param(
value=Array([0., 0., 0.], dtype=float32)
)
counter(
value=Array(0, dtype=int32)
)
<class 'abc.counter'>
(Intermediate(
value=Array(6.932987, dtype=float32)
),)
您可以使用 nnx.split
快速将不同类型的 NNX 变量区分开来。
当您只想将某些变量设置为可训练时,这将非常有用。
# Separate variables of different types with nnx.split
CountType = type(model.count)
static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)
print('All Params:', list(params.keys()))
print('All Counters:', list(counter.keys()))
print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))
model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time
y = model(x, mutable=True) # still works!
All Params: ['b', 'w']
All Counters: ['count']
All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']
NNX 到 Linen#
如果您定义了自定义 NNX 变量类型,您应该使用 nnx.register_variable_name_type_pair
注册它们的名字,以便它们被放到相应的集合中。
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)
class NNXMultiCollections(nnx.Module):
def __init__(self, din, dout, rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
self.lora = nnx.LoRA(din, 3, dout, rngs=rngs)
self.count = Count(jnp.array(0))
def __call__(self, x):
self.count += 1
return (x @ self.w.value) + self.lora(x)
xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(xkey, (2, 4))
model = bridge.to_linen(NNXMultiCollections, 4, 3)
var = model.init({'params': pkey, 'dropout': dkey}, x)
print('All Linen collections:', list(var.keys()))
print(var['params'])
All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts']
{'w': Array([[ 0.2916921 , 0.22780475, 0.06553137],
[ 0.17487915, -0.34043145, 0.24764155],
[ 0.6420431 , 0.6220095 , -0.44769976],
[ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}
分区元数据#
Flax 在原始 JAX 数组上使用了一个元数据包装器,用于注释变量应该如何分片。
在 Linen 中,这是一个可选功能,通过在初始化器上使用 nn.with_partitioning
来触发(有关更多信息,请参阅 Linen 分区元数据指南)。在 NNX 中,由于所有 NNX 变量都由 nnx.Variable
类包装,因此该类也将保存分片注释。
如果使用内置注释方法(即 Linen 的 nn.with_partitioning
和 NNX 的 nnx.with_partitioning
),bridge.ToNNX
和 bridge.ToLinen
API 会自动转换分片注释。
Linen 到 NNX#
即使您在 Linen 模块中没有使用任何分区元数据,变量 JAX 数组也会被转换为 nnx.Variable
,它在内部包装了真正的 JAX 数组。
如果您使用 nn.with_partitioning
来注释 Linen 模块的变量,该注释将被转换为对应 nnx.Variable
中的 .sharding
字段。
然后,您可以使用 nnx.with_sharding_constraint
在 jax.jit
编译的函数中明确地将数组放到带注释的分区中,以便使用每个数组在正确分片下初始化整个模型。
class LinenDotWithPartitioning(nn.Module):
out_dim: int
@nn.compact
def __call__(self, x):
w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(),
('in', 'out')),
(x.shape[-1], self.out_dim))
return x @ w
@nnx.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(
bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)
state = nnx.state(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
nnx.update(model, sharded_state)
return model
print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')
mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),
axis_names=('in', 'out'))
x = jax.random.normal(jax.random.key(42), (4, 32))
with mesh:
model = create_sharded_nnx_module(x)
print(type(model.w)) # `nnx.Param`
print(model.w.sharding) # The partition annotation attached with `w`
print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh
We have 8 fake JAX devices now to partition this model...
<class 'flax.nnx.variables.Param'>
('in', 'out')
GSPMDSharding({devices=[2,4]<=[8]})
NNX 到 Linen#
如果您没有使用 nnx.Variable
的任何元数据功能(即没有分片注释,没有注册的钩子),转换后的 Linen 模块将不会为您的 NNX 变量添加元数据包装器,您不必担心它。
但是,如果您在 NNX 变量中添加了分片注释,ToLinen
将它们转换为默认的 Linen 分区元数据类,名为 bridge.NNXMeta
,保留了您放在 NNX 变量中的所有元数据。
与任何 Linen 元数据包装器一样,您可以使用 linen.unbox()
获取原始的 JAX 数组树。
class NNXDotWithParititioning(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))
def __call__(self, x: jax.Array):
return x @ self.w
x = jax.random.normal(jax.random.key(42), (4, 32))
@jax.jit
def create_sharded_variables(key, x):
model = bridge.to_linen(NNXDotWithParititioning, 32, 64)
variables = model.init(key, x)
# A `NNXMeta` wrapper of the underlying `nnx.Param`
assert type(variables['params']['w']) == bridge.NNXMeta
# The annotation coming from the `nnx.Param` => (in, out)
assert variables['params']['w'].metadata['sharding'] == ('in', 'out')
unboxed_variables = nn.unbox(variables)
variable_pspecs = nn.get_partition_spec(variables)
assert isinstance(unboxed_variables['params']['w'], jax.Array)
assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')
sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint,
nn.unbox(variables),
nn.get_partition_spec(variables))
return sharded_vars
with mesh:
variables = create_sharded_variables(jax.random.key(0), x)
# The underlying JAX array is sharded across the 2x4 mesh
print(variables['params']['w'].sharding)
GSPMDSharding({devices=[2,4]<=[8]})
提升的变换#
一般来说,如果您想对 nnx.bridge
转换的模块应用 Linen/NNX 风格的提升变换,只需按照通常的 Linen/NNX 语法进行。
对于 Linen 风格的变换,请注意 bridge.ToLinen
是顶级模块类,因此您可能希望将其用作变换的第一个参数(在大多数情况下,变换需要是 linen.Module
类)。
Linen 到 NNX#
NNX 风格的提升变换类似于 JAX 变换,它们作用于函数。
class NNXVmapped(nnx.Module):
def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs):
self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs)
self.vmap_axis_size = vmap_axis_size
def __call__(self, x):
@nnx.split_rngs(splits=self.vmap_axis_size)
@nnx.vmap(in_axes=(0, 0), axis_size=self.vmap_axis_size)
def vmap_fn(submodule, x):
return submodule(x)
return vmap_fn(self.linen_dot, x)
x = jax.random.normal(jax.random.key(0), (4, 32))
model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)
print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped
y = model(x)
print(y.shape)
(4, 32, 64)
(4, 64)
NNX 到 Linen#
请注意,bridge.ToLinen
是顶级模块类,因此您可能希望将其用作变换的第一个参数(在大多数情况下,变换需要是 linen.Module
类)。
此外,由于 bridge.ToLinen
引入了这个额外的 nnx
集合,您需要在使用轴变换变换(linen.vmap
,linen.scan
等)时标记它,以确保它们被传递到内部。
class LinenVmapped(nn.Module):
dout: int
@nn.compact
def __call__(self, x):
inner = nn.vmap(bridge.ToLinen, variable_axes={'params': 0, 'nnx': None}, split_rngs={'params': True}
)(nnx.Linear, args=(x.shape[-1], self.dout))
return inner(x)
x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenVmapped(64)
var = model.init(jax.random.key(0), x)
print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped
y = model.apply(var, x)
print(y.shape)
(4, 32, 64)
(4, 64)