从 Linen 到 NNX 的演变#
本指南将引导您了解 Flax Linen 和 Flax NNX 模型之间的差异,以及并排比较,以帮助您将代码从 Linen API 迁移到 NNX。
在阅读本指南之前,强烈建议您通读 Flax NNX 基础知识,以了解 Flax NNX 的核心概念和代码示例。
本指南主要介绍如何将任意 Linen 代码转换为 NNX。如果您想安全地逐步转换代码库,请查看允许您 将 NNX 和 Linen 代码一起使用 的指南。
基本模块定义#
Linen 和 NNX 都使用 Module
作为表达神经库层的默认方式。Linen 和 NNX 模块之间有两个基本区别
**无状态 vs. 有状态**: Linen 模块实例是无状态的:变量从纯函数式的
.init()
调用返回,并单独管理。然而,NNX 模块将自己的变量作为此 Python 对象的属性。**延迟 vs. 急切**: Linen 模块仅在实际看到输入时才分配空间来创建变量。而 NNX 模块实例在其被实例化的瞬间创建其变量,而无需看到示例输入。
Linen 可以使用
@nn.compact
装饰器在一个方法中定义模型,并使用来自输入样本的形状推断,而 NNX 模块通常要求额外的形状信息,以便在__init__
期间创建所有参数,并在__call__
中单独定义计算。
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
变量创建#
要为 Linen 模型生成模型参数,您可以使用 jax.random.key
以及模型将接收的一些示例输入来调用 init
方法。结果是一个嵌套的 JAX 数组字典,需要单独维护和传递。
在 NNX 中,当用户实例化模型时,模型参数会自动初始化,变量存储在模块(或其子模块)中,作为属性。您仍然需要为它提供一个 RNG 密钥,但该密钥将被封装在 nnx.Rngs
类中并存储在其中,并在需要时生成更多 RNG 密钥。
如果您想以无状态的字典形式访问 NNX 模型参数,以便进行检查点保存或模型手术,请查看 NNX 拆分/合并 API。
model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]
assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# parameters were already initialized during model instantiation
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
训练步骤和编译#
现在我们编写一个训练步骤并使用 JAX 实时编译进行编译。请注意这里的一些区别
Linen 使用
@jax.jit
编译训练步骤,而 NNX 使用@nnx.jit
。jax.jit
仅接受纯无状态参数,但nnx.jit
允许参数是有状态的 NNX 模块。这大大减少了训练步骤所需的代码行数。类似地,Linen 使用
jax.grad()
返回一个原始梯度字典,而 NNX 可以使用nnx.grad
返回模块作为 NNXState
字典的梯度。要使用 NNX 与常规jax.grad
,您需要使用 NNX 拆分/合并 API。如果您已经在使用 Optax 优化器,例如
optax.adamw
(而不是这里显示的原始jax.tree.map
计算),请查看 nnx.Optimizer 示例,以了解一种更简洁的训练和更新模型的方法。
Linen 训练步骤需要返回一个参数树,作为下一步的输入。另一方面,NNX 的步骤不需要返回任何内容,因为
model
已在nnx.jit
中就地更新。NNX 模块是有状态的,并在内部自动跟踪一些内容,例如 RNG 密钥和 BatchNorm 统计信息。这就是为什么您不需要在每一步都显式传递 RNG 密钥。请注意,您可以使用 nnx.reseed 重置其底层 RNG 状态。
在 Linen 中,您需要显式定义并传入一个参数
training
来控制nn.Dropout
的行为(即它的deterministic
标志,这意味着随机 dropout 仅在training=True
时发生)。在 NNX 中,您可以调用model.train()
将nnx.Dropout
自动切换到训练模式。相反,调用model.eval()
关闭训练模式。您可以在其 API 参考 中详细了解此 API 的作用。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(inputs)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
集合和变量类型#
Linen 和 NNX API 之间的一个关键区别在于我们如何将变量分组到类别中。在 Linen 中,我们使用不同的集合;在 NNX 中,由于所有变量都应该是顶级 Python 属性,因此您使用不同的变量类型。
您可以自由地创建自己的变量类型,作为 nnx.Variable
的子类。
对于所有内置的 Flax Linen 层和集合,NNX 已经创建了相应的层和变量类型。例如
nn.Dense
创建params
->nnx.Linear
创建nnx.Param
。
nn.BatchNorm
创建batch_stats
->nnx.BatchNorm
创建nnx.BatchStats
。
linen.Module.sow()
创建intermediates
->nnx.Module.sow()
创建nnx.Intermediates
。
您也可以简单地通过将其分配给模块属性来获取中间值,例如
self.sowed = nnx.Intermediates(x)
。这将类似于 Linen 的self.variable('intermediates' 'sowed', lambda: x)
。
class Block(nn.Module):
features: int
def setup(self):
self.dense = nn.Dense(self.features)
self.batchnorm = nn.BatchNorm(momentum=0.99)
self.count = self.variable('counter', 'count',
lambda: jnp.zeros((), jnp.int32))
@nn.compact
def __call__(self, x, training: bool):
x = self.dense(x)
x = self.batchnorm(x, use_running_average=not training)
self.count.value += 1
x = jax.nn.relu(x)
return x
x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape # (4, )
variables['counter']['count'] # 1
class Counter(nnx.Variable): pass
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
self.count = Counter(jnp.array(0))
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
self.count += 1
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
model.count # Counter(value=...)
如果您想从变量树中提取某些数组,您可以在 Linen 中访问特定的字典路径,或者使用 nnx.split
在 NNX 中区分不同的类型。下面的代码是一个更简单的示例,查看 过滤器 API 指南,了解更复杂的过滤表达式。
params, batch_stats, counter = (
variables['params'], variables['batch_stats'], variables['counter'])
params.keys() # ['dense', 'batchnorm']
batch_stats.keys() # ['batchnorm']
counter.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
model, nnx.Param, nnx.BatchStat, Counter)
params.keys() # ['batchnorm', 'linear']
batch_stats.keys() # ['batchnorm']
count.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)
使用多个方法#
在本节中,我们将看看如何在两个框架中使用多个方法。作为一个例子,我们将实现一个具有三个方法的自动编码器模型:encode
、decode
和 __call__
。
与以前一样,我们定义编码器和解码器层,而无需传入输入形状,因为模块参数将使用 Linen 中的形状推断进行延迟初始化。在 NNX 中,我们必须传入输入形状,因为模块参数将在没有形状推断的情况下急切地初始化。
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
变量结构如下
# variables['params']
{
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
},
'encoder': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
要调用除 __call__
以外的方法,在 Linen 中,您仍然需要使用 apply
API,而在 NNX 中,您可以直接调用该方法。
z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))
提升变换#
Flax API 提供了一组变换,我们称之为提升变换,它们将 JAX 变换包装起来,以便可以与模块一起使用。
Linen 中的大多数变换在 NNX 中没有太大变化。请参阅下一节(在层上扫描),了解代码差异更大的情况。
首先,我们将定义一个名为 RNNCell
的模块,它将包含 RNN 单步的逻辑。我们还将定义一个名为 initial_state
的方法,用于初始化 RNN 的状态(也称为 carry
)。与 jax.lax.scan
类似,RNNCell.__call__
方法将是一个函数,它接收状态和输入,并返回新的状态和输出。在本例中,状态和输出是相同的。
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个名为 RNN
的模块,它将包含整个 RNN 的逻辑。
在 Linen 中,我们将使用 nn.scan
来定义一个新的临时类型,它将封装 RNNCell
。在这个过程中,我们还将指定 nn.scan
将广播 params
集合(所有步骤共享相同的参数),并且不拆分 params
的 rng 流(因此所有步骤都使用相同的参数初始化),最后,我们将指定我们希望扫描在输入的第二轴上运行,并将输出沿着第二轴堆叠。然后,我们将立即使用此临时类型来创建一个提升的 RNNCell
的实例,并使用它来创建 carry
并运行 __call__
方法,该方法将 scan
遍历序列。
在 NNX 中,我们定义了一个名为 scan_fn
的扫描函数,它将使用在 __init__
中定义的 RNNCell
来扫描序列,并显式地设置 in_axes=(nnx.Carry, None, 1)
,Carry
表示 carry
参数将是状态,None
表示 cell
将广播到所有步骤,而 1
表示 x
将在轴 1 上进行扫描。
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(
RNNCell, variable_broadcast='params',
split_rngs={'params': False}, in_axes=1, out_axes=1
)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))
y = model(x)
跨层扫描#
一般来说,Linen 和 NNX 的提升转换应该看起来相同。但是,NNX 提升转换旨在更接近其低级 JAX 对应物,因此我们在某些 Linen 提升转换中放弃了一些假设。这种跨层扫描用例将是一个很好的例子来说明这一点。
跨层扫描是一种技术,其中,我们希望将输入通过一系列 N 个重复的层运行,将每个层的输出作为下一个层的输入传递。这种模式可以显着减少大型模型的编译时间。在本例中,我们将在顶层模块 MLP
中重复模块 Block
5 次。
在 Linen 中,我们在模块 Block
上应用一个 nn.scan
来创建一个更大的模块 ScanBlock
,它包含 5 个 Block
。它将在初始化时自动创建一个形状为 (5, 64, 64)
的大型参数,并在调用时遍历每个 (64, 64)
切片,总共 5 次,就像 jax.lax.scan
一样。
但如果你仔细想想,其实在初始化时并不需要 jax.lax.scan
操作。那里发生的事情更像是 jax.vmap
操作——你得到了一个接受 (in_dim, out_dim)
的 Block
,并且你对它进行 num_layers
次的“vmap”以创建一个更大的数组。
在 NNX 中,我们利用模型初始化和运行代码完全解耦的事实,并使用 nnx.vmap
来初始化底层块,以及 nnx.scan
来运行模型输入。
有关 NNX 转换的更多信息,请查看 转换指南。
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
model = MLP(64, num_layers=5)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
在本例中,还需要解释一些其他细节
`nnx.split_rngs` 装饰器是什么?NNX 转换完全独立于 RNG 状态,这使得它们的行為更像 JAX 转换,但与处理 RNG 状态的 Linen 转换有所区别。为了恢复此功能,
nnx.split_rngs
装饰器允许你在将 RNG 传递给装饰的函数之前拆分Rngs
,并在之后将其“降低”,以便它们可以在外部使用。在这里,我们拆分 RNG 密钥,因为
jax.vmap
和jax.lax.scan
要求如果每个内部操作都需要自己的密钥,则提供一个 RNG 密钥列表。因此,对于MLP
中的 5 层,我们在下降到 JAX 转换之前,从其参数中拆分并提供 5 个不同的 RNG 密钥。请注意,实际上
create_block()
知道它需要创建 5 层,正是因为它看到了 5 个 RNG 密钥,因为in_axes=(0,)
表示vmap
将查看第一个参数的第一个维度以了解它将映射的尺寸。对于
forward()
也是如此,它查看第一个参数(也称为model
)中的变量,以确定它需要扫描多少次。nnx.split_rngs
在这里实际上拆分了model
中的 RNG 状态。(如果Block
没有 dropout,则不需要nnx.split_rngs
行,因为它无论如何都不会使用任何 RNG 密钥。)
为什么 NNX 中的 `Block` 不需要接收和返回那个额外的虚拟值?这是 jax.lax.scan 的要求。NNX 简化了这一点,因此现在你可以选择忽略第二个输入/输出,如果你设置了
out_axes=nnx.Carry
而不是默认的(nnx.Carry, 0)
。这是 NNX 转换与 JAX 转换 API 不同的少数情况之一。
这需要更多的代码行,但它更精确地表达了每次发生的事情。由于 NNX 提升转换变得更接近 JAX API,因此建议在使用 NNX 版本之前,对底层的 JAX 转换有很好的了解。
现在看一下两边的变量树
# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
在 NNX 中使用 `TrainState`#
Flax 提供了一个方便的 TrainState
数据类,用于捆绑模型、参数和优化器。这在 NNX 时代并不真正需要,但本节将展示如何围绕它构建 NNX 代码,以满足任何向后兼容需求。
在 NNX 中,我们必须首先在模型上调用 nnx.split
,以获取分离的 GraphDef
和 State
对象。我们可以传入 nnx.Param
将所有可训练参数过滤到单个 State
中,并传入 ...
用于剩余的变量。我们还需要子类化 TrainState
以添加一个用于其他变量的字段。然后,我们可以将 GraphDef.apply
作为应用函数,State
作为参数和其他变量,以及一个优化器作为参数传递给 TrainState
构造函数。需要注意的是,GraphDef.apply
将接收 State
作为参数,并返回一个可调用的函数。此函数可以在输入上调用,以输出模型的 logits,以及更新后的 GraphDef
和 State
对象。请注意,我们还使用了 @jax.jit
,因为我们没有将 NNX 模块传递给 train_step
。
from flax.training import train_state
sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(key, state, inputs, labels):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
inputs, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state
from flax.training import train_state
model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
class TrainState(train_state.TrainState):
other_variables: nnx.State
state = TrainState.create(
apply_fn=graphdef.apply,
params=params,
other_variables=other_variables,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, other_variables):
logits, (graphdef, new_state) = state.apply_fn(
params,
other_variables
)(inputs) # <== inputs
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params, state.other_variables)
state = state.apply_gradients(grads=grads)
return state