从 Linen 到 NNX 的演变#

本指南将引导您了解 Flax Linen 和 Flax NNX 模型之间的差异,以及并排比较,以帮助您将代码从 Linen API 迁移到 NNX。

在阅读本指南之前,强烈建议您通读 Flax NNX 基础知识,以了解 Flax NNX 的核心概念和代码示例。

本指南主要介绍如何将任意 Linen 代码转换为 NNX。如果您想安全地逐步转换代码库,请查看允许您 将 NNX 和 Linen 代码一起使用 的指南。

基本模块定义#

Linen 和 NNX 都使用 Module 作为表达神经库层的默认方式。Linen 和 NNX 模块之间有两个基本区别

  • **无状态 vs. 有状态**: Linen 模块实例是无状态的:变量从纯函数式的 .init() 调用返回,并单独管理。然而,NNX 模块将自己的变量作为此 Python 对象的属性。

  • **延迟 vs. 急切**: Linen 模块仅在实际看到输入时才分配空间来创建变量。而 NNX 模块实例在其被实例化的瞬间创建其变量,而无需看到示例输入。

    • Linen 可以使用 @nn.compact 装饰器在一个方法中定义模型,并使用来自输入样本的形状推断,而 NNX 模块通常要求额外的形状信息,以便在 __init__ 期间创建所有参数,并在 __call__ 中单独定义计算。

import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
from flax import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

变量创建#

要为 Linen 模型生成模型参数,您可以使用 jax.random.key 以及模型将接收的一些示例输入来调用 init 方法。结果是一个嵌套的 JAX 数组字典,需要单独维护和传递。

在 NNX 中,当用户实例化模型时,模型参数会自动初始化,变量存储在模块(或其子模块)中,作为属性。您仍然需要为它提供一个 RNG 密钥,但该密钥将被封装在 nnx.Rngs 类中并存储在其中,并在需要时生成更多 RNG 密钥。

如果您想以无状态的字典形式访问 NNX 模型参数,以便进行检查点保存或模型手术,请查看 NNX 拆分/合并 API

model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))


# parameters were already initialized during model instantiation

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

训练步骤和编译#

现在我们编写一个训练步骤并使用 JAX 实时编译进行编译。请注意这里的一些区别

  • Linen 使用 @jax.jit 编译训练步骤,而 NNX 使用 @nnx.jitjax.jit 仅接受纯无状态参数,但 nnx.jit 允许参数是有状态的 NNX 模块。这大大减少了训练步骤所需的代码行数。

  • 类似地,Linen 使用 jax.grad() 返回一个原始梯度字典,而 NNX 可以使用 nnx.grad 返回模块作为 NNX State 字典的梯度。要使用 NNX 与常规 jax.grad,您需要使用 NNX 拆分/合并 API

    • 如果您已经在使用 Optax 优化器,例如 optax.adamw(而不是这里显示的原始 jax.tree.map 计算),请查看 nnx.Optimizer 示例,以了解一种更简洁的训练和更新模型的方法。

  • Linen 训练步骤需要返回一个参数树,作为下一步的输入。另一方面,NNX 的步骤不需要返回任何内容,因为 model 已在 nnx.jit 中就地更新。

  • NNX 模块是有状态的,并在内部自动跟踪一些内容,例如 RNG 密钥和 BatchNorm 统计信息。这就是为什么您不需要在每一步都显式传递 RNG 密钥。请注意,您可以使用 nnx.reseed 重置其底层 RNG 状态。

  • 在 Linen 中,您需要显式定义并传入一个参数 training 来控制 nn.Dropout 的行为(即它的 deterministic 标志,这意味着随机 dropout 仅在 training=True 时发生)。在 NNX 中,您可以调用 model.train()nnx.Dropout 自动切换到训练模式。相反,调用 model.eval() 关闭训练模式。您可以在其 API 参考 中详细了解此 API 的作用。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)

  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(inputs)




    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  _, params, rest = nnx.split(model, nnx.Param, ...)
  params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, nnx.GraphState.merge(params, rest))

集合和变量类型#

Linen 和 NNX API 之间的一个关键区别在于我们如何将变量分组到类别中。在 Linen 中,我们使用不同的集合;在 NNX 中,由于所有变量都应该是顶级 Python 属性,因此您使用不同的变量类型。

您可以自由地创建自己的变量类型,作为 nnx.Variable 的子类。

对于所有内置的 Flax Linen 层和集合,NNX 已经创建了相应的层和变量类型。例如

  • nn.Dense 创建 params -> nnx.Linear 创建 nnx.Param

  • nn.BatchNorm 创建 batch_stats -> nnx.BatchNorm 创建 nnx.BatchStats

  • linen.Module.sow() 创建 intermediates -> nnx.Module.sow() 创建 nnx.Intermediates

    • 您也可以简单地通过将其分配给模块属性来获取中间值,例如 self.sowed = nnx.Intermediates(x)。这将类似于 Linen 的 self.variable('intermediates' 'sowed', lambda: x)

class Block(nn.Module):
  features: int
  def setup(self):
    self.dense = nn.Dense(self.features)
    self.batchnorm = nn.BatchNorm(momentum=0.99)
    self.count = self.variable('counter', 'count',
                                lambda: jnp.zeros((), jnp.int32))


  @nn.compact
  def __call__(self, x, training: bool):
    x = self.dense(x)
    x = self.batchnorm(x, use_running_average=not training)
    self.count.value += 1
    x = jax.nn.relu(x)
    return x

x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape         # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape  # (4, )
variables['counter']['count']                        # 1
class Counter(nnx.Variable): pass

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )
    self.count = Counter(jnp.array(0))

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)
    self.count += 1
    x = jax.nn.relu(x)
    return x



model = Block(4, 4, rngs=nnx.Rngs(0))

model.linear.kernel   # Param(value=...)
model.batchnorm.mean  # BatchStat(value=...)
model.count           # Counter(value=...)

如果您想从变量树中提取某些数组,您可以在 Linen 中访问特定的字典路径,或者使用 nnx.split 在 NNX 中区分不同的类型。下面的代码是一个更简单的示例,查看 过滤器 API 指南,了解更复杂的过滤表达式。

params, batch_stats, counter = (
  variables['params'], variables['batch_stats'], variables['counter'])
params.keys()       # ['dense', 'batchnorm']
batch_stats.keys()  # ['batchnorm']
counter.keys()      # ['count']

# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
  model, nnx.Param, nnx.BatchStat, Counter)
params.keys()       # ['batchnorm', 'linear']
batch_stats.keys()  # ['batchnorm']
count.keys()        # ['count']

# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)

使用多个方法#

在本节中,我们将看看如何在两个框架中使用多个方法。作为一个例子,我们将实现一个具有三个方法的自动编码器模型:encodedecode__call__

与以前一样,我们定义编码器和解码器层,而无需传入输入形状,因为模块参数将使用 Linen 中的形状推断进行延迟初始化。在 NNX 中,我们必须传入输入形状,因为模块参数将在没有形状推断的情况下急切地初始化。

class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):



  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))

变量结构如下

# variables['params']
{
  decoder: {
      bias: (784,),
      kernel: (256, 784),
  },
  encoder: {
      bias: (256,),
      kernel: (784, 256),
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

要调用除 __call__ 以外的方法,在 Linen 中,您仍然需要使用 apply API,而在 NNX 中,您可以直接调用该方法。

z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))

提升变换#

Flax API 提供了一组变换,我们称之为提升变换,它们将 JAX 变换包装起来,以便可以与模块一起使用。

Linen 中的大多数变换在 NNX 中没有太大变化。请参阅下一节(在层上扫描),了解代码差异更大的情况。

首先,我们将定义一个名为 RNNCell 的模块,它将包含 RNN 单步的逻辑。我们还将定义一个名为 initial_state 的方法,用于初始化 RNN 的状态(也称为 carry)。与 jax.lax.scan 类似,RNNCell.__call__ 方法将是一个函数,它接收状态和输入,并返回新的状态和输出。在本例中,状态和输出是相同的。

class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,我们将定义一个名为 RNN 的模块,它将包含整个 RNN 的逻辑。

在 Linen 中,我们将使用 nn.scan 来定义一个新的临时类型,它将封装 RNNCell。在这个过程中,我们还将指定 nn.scan 将广播 params 集合(所有步骤共享相同的参数),并且不拆分 params 的 rng 流(因此所有步骤都使用相同的参数初始化),最后,我们将指定我们希望扫描在输入的第二轴上运行,并将输出沿着第二轴堆叠。然后,我们将立即使用此临时类型来创建一个提升的 RNNCell 的实例,并使用它来创建 carry 并运行 __call__ 方法,该方法将 scan 遍历序列。

在 NNX 中,我们定义了一个名为 scan_fn 的扫描函数,它将使用在 __init__ 中定义的 RNNCell 来扫描序列,并显式地设置 in_axes=(nnx.Carry, None, 1)Carry 表示 carry 参数将是状态,None 表示 cell 将广播到所有步骤,而 1 表示 x 将在轴 1 上进行扫描。

class RNN(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))

y = model(x)

跨层扫描#

一般来说,Linen 和 NNX 的提升转换应该看起来相同。但是,NNX 提升转换旨在更接近其低级 JAX 对应物,因此我们在某些 Linen 提升转换中放弃了一些假设。这种跨层扫描用例将是一个很好的例子来说明这一点。

跨层扫描是一种技术,其中,我们希望将输入通过一系列 N 个重复的层运行,将每个层的输出作为下一个层的输入传递。这种模式可以显着减少大型模型的编译时间。在本例中,我们将在顶层模块 MLP 中重复模块 Block 5 次。

在 Linen 中,我们在模块 Block 上应用一个 nn.scan 来创建一个更大的模块 ScanBlock,它包含 5 个 Block。它将在初始化时自动创建一个形状为 (5, 64, 64) 的大型参数,并在调用时遍历每个 (64, 64) 切片,总共 5 次,就像 jax.lax.scan 一样。

但如果你仔细想想,其实在初始化时并不需要 jax.lax.scan 操作。那里发生的事情更像是 jax.vmap 操作——你得到了一个接受 (in_dim, out_dim)Block,并且你对它进行 num_layers 次的“vmap”以创建一个更大的数组。

在 NNX 中,我们利用模型初始化和运行代码完全解耦的事实,并使用 nnx.vmap 来初始化底层块,以及 nnx.scan 来运行模型输入。

有关 NNX 转换的更多信息,请查看 转换指南

class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int




  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

model = MLP(64, num_layers=5)
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

在本例中,还需要解释一些其他细节

  • `nnx.split_rngs` 装饰器是什么?NNX 转换完全独立于 RNG 状态,这使得它们的行為更像 JAX 转换,但与处理 RNG 状态的 Linen 转换有所区别。为了恢复此功能,nnx.split_rngs 装饰器允许你在将 RNG 传递给装饰的函数之前拆分 Rngs,并在之后将其“降低”,以便它们可以在外部使用。

    • 在这里,我们拆分 RNG 密钥,因为 jax.vmapjax.lax.scan 要求如果每个内部操作都需要自己的密钥,则提供一个 RNG 密钥列表。因此,对于 MLP 中的 5 层,我们在下降到 JAX 转换之前,从其参数中拆分并提供 5 个不同的 RNG 密钥。

    • 请注意,实际上 create_block() 知道它需要创建 5 层,正是因为它看到了 5 个 RNG 密钥,因为 in_axes=(0,) 表示 vmap 将查看第一个参数的第一个维度以了解它将映射的尺寸。

    • 对于 forward() 也是如此,它查看第一个参数(也称为 model)中的变量,以确定它需要扫描多少次。nnx.split_rngs 在这里实际上拆分了 model 中的 RNG 状态。(如果 Block 没有 dropout,则不需要 nnx.split_rngs 行,因为它无论如何都不会使用任何 RNG 密钥。)

  • 为什么 NNX 中的 `Block` 不需要接收和返回那个额外的虚拟值?这是 jax.lax.scan 的要求。NNX 简化了这一点,因此现在你可以选择忽略第二个输入/输出,如果你设置了 out_axes=nnx.Carry 而不是默认的 (nnx.Carry, 0)

    • 这是 NNX 转换与 JAX 转换 API 不同的少数情况之一。

这需要更多的代码行,但它更精确地表达了每次发生的事情。由于 NNX 提升转换变得更接近 JAX API,因此建议在使用 NNX 版本之前,对底层的 JAX 转换有很好的了解。

现在看一下两边的变量树

# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
  ScanBlock_0: {
    Dense_0: {
      bias: (5, 64),
      kernel: (5, 64, 64),
    },
  },
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
  'blocks': {
    'linear': {
      'bias': VariableState(type=Param, value=(5, 64)),
      'kernel': VariableState(type=Param, value=(5, 64, 64))
    }
  }
})

在 NNX 中使用 `TrainState`#

Flax 提供了一个方便的 TrainState 数据类,用于捆绑模型、参数和优化器。这在 NNX 时代并不真正需要,但本节将展示如何围绕它构建 NNX 代码,以满足任何向后兼容需求。

在 NNX 中,我们必须首先在模型上调用 nnx.split,以获取分离的 GraphDefState 对象。我们可以传入 nnx.Param 将所有可训练参数过滤到单个 State 中,并传入 ... 用于剩余的变量。我们还需要子类化 TrainState 以添加一个用于其他变量的字段。然后,我们可以将 GraphDef.apply 作为应用函数,State 作为参数和其他变量,以及一个优化器作为参数传递给 TrainState 构造函数。需要注意的是,GraphDef.apply 将接收 State 作为参数,并返回一个可调用的函数。此函数可以在输入上调用,以输出模型的 logits,以及更新后的 GraphDefState 对象。请注意,我们还使用了 @jax.jit,因为我们没有将 NNX 模块传递给 train_step

from flax.training import train_state

sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']




state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,

  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)


  state = state.apply_gradients(grads=grads)

  return state
from flax.training import train_state

model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(
      params,
      other_variables

    )(inputs) # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)


  state = state.apply_gradients(grads=grads)

  return state