亚麻基础#

Flax NNX 是一个新的简化 API,旨在简化在 JAX 中创建、检查、调试和分析神经网络。它通过添加对 Python 引用语义的一流支持来实现这一点。这允许用户使用常规 Python 对象来表达他们的模型,这些对象被建模为 PyGraphs(而不是 pytrees),从而实现引用共享和可变性。这种 API 设计应该让 PyTorch 或 Keras 用户感到宾至如归。

在本指南中,您将了解

  • Flax nnx.Module 系统:创建和初始化自定义 Linear 层的示例。

    • 有状态计算:创建 Flax nnx.Variable 并更新其值(例如,正向传递期间所需的 state 更新)的示例。

    • 嵌套的 nnx.Module:包含 Linearnnx.Dropoutnnx.BatchNorm 层的 MLP 示例。

    • 模型手术:将模型内的自定义 Linear 层替换为自定义 LoraLinear 层的示例。

  • Flax 转换:使用 nnx.jit 进行自动状态管理的示例。

  • Flax NNX 函数式 API:具有 nnx.Param 的自定义 StatefulLinear 层的示例,可以对 state 进行细粒度控制。

设置#

使用 pip 安装 Flax 并赋予必要的依赖项

# ! pip install -U flax treescope
from flax import nnx
import jax
import jax.numpy as jnp

Flax nnx.Module 系统#

Flax nnx.ModuleFlax LinenHaiku 中的其他 Module 系统之间的主要区别在于,在 NNX 中,一切都 明确。这意味着,除其他外,

  1. nnx.Module 本身直接保存 state(例如参数)。

  2. 用户通过线程处理 PRNG state。

  3. 必须在初始化时提供所有形状信息(无形状推断)。

让我们从创建一个 Linear nnx.Module 开始。以下代码显示了

  • 动态 state 通常存储在 nnx.Param 中,而静态 state(NNX 未处理的所有类型),例如整数或字符串,则直接存储。

  • 类型为 jax.Arraynumpy.ndarray 的属性也被视为动态 state,尽管将它们存储在 nnx.Variable 中(例如 Param)更可取。

  • nnx.Rngs 对象可用于根据传递给构造函数的根 PRNG 密钥获取新的唯一密钥。

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

另请注意

  • 可以使用 value 属性访问 nnx.Variable 的内部值,但为了方便起见,它们实现了所有数值运算符,并且可以直接在算术表达式中使用(如上面的代码所示)。

要初始化 Flax nnx.Module,只需调用构造函数,Module 的所有参数通常都会急切创建。由于 nnx.Module 保存自己的 state 方法,因此您可以直接调用它们,而无需单独的 apply 方法。这对于调试非常方便,允许您直接检查模型的整个结构。

model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))

print(y)
nnx.display(model)
[[1.245453   0.74195766 0.8553282  0.6763327  1.2617068 ]]
(正在加载...)

上面的 nnx.display 可视化是使用出色的 Treescope 库生成的。

有状态计算#

实现层(例如 nnx.BatchNorm)需要在正向传递期间执行 state 更新。在 Flax NNX 中,您只需要创建一个 nnx.Variable,并在正向传递期间更新其 .value

class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count += 1

counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)

在 JAX 中通常避免使用可变引用。但 Flax NNX 提供了健全的机制来处理它们,如本指南后面部分所示。

嵌套的 nnx.Module#

Flax nnx.Module 可用于以嵌套结构组合其他 Module。这些可以作为属性直接分配,或者在任何(嵌套的)pytree 类型的属性内分配,例如 listdicttuple 等。

以下示例展示了如何通过子类化 nnx.Module 来定义一个简单的 MLP。该模型包含两个 Linear 层、一个 nnx.Dropout 层和一个 nnx.BatchNorm

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

y = model(x=jnp.ones((3, 2)))

nnx.display(model)
(正在加载...)

在 Flax 中,nnx.Dropout 是一个有状态的模块,它存储一个 nnx.Rngs 对象,以便它可以在前向传递过程中生成新的掩码,而无需用户每次都传递新的密钥。

模型手术#

Flax nnx.Module 默认情况下是可变的。这意味着它们的结构可以在任何时候更改,这使得模型手术非常容易,因为任何子 Module 属性都可以被其他任何东西替换,例如新的 Module、现有的共享 Module、不同类型的 Module 等等。此外,nnx.Variable 也可以被修改或替换/共享。

以下示例展示了如何将之前示例中 MLP 模型中的 Linear 层替换为 LoraLinear

class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
    self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)

# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)

y = model(x=jnp.ones((3, 2)))

nnx.display(model)
(正在加载...)

Flax 变换#

Flax NNX 变换 (transforms) 扩展了 JAX 变换 以支持 nnx.Module 和其他对象。它们作为其等效 JAX 对应物的超集,增加了对对象状态的感知,并提供了用于变换它的额外 API。

Flax 变换的主要特性之一是保留了引用语义,这意味着只要在变换规则内合法,在变换内部发生的任何对象图变异都会传播到外部。在实践中,这意味着 Flax 程序可以使用命令式代码来表达,从而极大地简化了用户体验。

在以下示例中,您定义了一个 train_step 函数,它接受一个 MLP 模型、一个 nnx.Optimizer 和一批数据,并返回该步的损失。损失和梯度是使用 nnx.value_and_grad 变换在 loss_fn 上计算的。梯度被传递给优化器的 nnx.Optimizer.update 方法,以更新 model 的参数。

import optax

# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # In place updates.

  return loss

x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000279, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)

这个示例中有两件事值得一提

  1. 对每个 nnx.BatchNormnnx.Dropout 层状态的更新会自动从 loss_fn 内部传播到 train_step,一直传播到外部的 model 引用。

  2. optimizer 持有一个对 model 的可变引用 - 这种关系在 train_step 函数内部被保留,这使得可以使用优化器单独更新模型的参数成为可能。

nnx.scan 在层上#

下一个示例使用 Flax nnx.vmap 创建多个 MLP 层的堆栈,并使用 nnx.scan 迭代地将堆栈中的每一层应用于输入。

在下面的代码中注意以下几点

  1. 自定义的 create_model 函数接受一个键并返回一个 MLP 对象,因为您创建了五个键并使用 nnx.vmapcreate_model 上,创建了一个由 5 个 MLP 对象组成的堆栈。

  2. The nnx.scan 用于迭代地将堆栈中的每个 MLP 应用于输入 x

  3. The nnx.scan (有意地) 偏离了 jax.lax.scan,而是模仿了 nnx.vmap,后者更具表现力。 nnx.scan 允许指定多个输入、每个输入/输出的扫描轴以及进位的位置。

  4. The nnx.BatchNormnnx.Dropout 层的状态更新会自动由 nnx.scan 传播。

@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  x = model(x)
  return x

x = jnp.ones((3, 10))
y = forward(model, x)

print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)
(正在加载...)

Flax NNX 变换如何实现这一点?为了理解 Flax NNX 对象如何与 JAX 变换交互,下一节解释了 Flax NNX 函数式 API。

Flax 函数式 API#

Flax NNX 函数式 API 在引用/对象语义和值/pytree 语义之间建立了清晰的界限。它还允许对 Flax Linen 和 Haiku 用户所习惯的状态进行相同程度的细粒度控制。Flax NNX 函数式 API 包含三个基本方法:nnx.splitnnx.mergennx.update.

下面是 StatefulLinear nnx.Module 的示例,它使用了函数式 API。它包含

class Count(nnx.Variable): pass

class StatefulLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array):
    self.count += 1
    return x @ self.w + self.b

model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))

nnx.display(model)
(正在加载...)

StateGraphDef#

Flax nnx.Module 可以使用 nnx.split 函数分解为 nnx.Statennx.GraphDef

graphdef, state = nnx.split(model)

nnx.display(graphdef, state)
(正在加载...)
(正在加载...)

splitmergeupdate#

Flax 的 nnx.mergennx.split 的反向操作。它接受 nnx.GraphDef + nnx.State 并重建 nnx.Module。以下示例展示了这一点

  • 通过按顺序使用 nnx.splitnnx.merge,任何 Module 都可以被提升以用于任何 JAX 变换。

  • nnx.update 可以使用给定的 nnx.State 的内容原地更新对象。

  • 这种模式用于将状态从转换传播回外部的源对象。

print(f'{model.count.value = }')

# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
  # 2. Use `nnx.merge` to create a new model inside the JAX transformation.
  model = nnx.merge(graphdef, state)
  # 3. Call the `nnx.Module`
  y = model(x)
  # 4. Use `nnx.split` to propagate `nnx.State` updates.
  _, state = nnx.split(model)
  return y, state

y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)

print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)

这种模式的关键在于,在转换上下文中(包括基本的急切解释器)使用可变引用是可以的,但跨越边界时必须使用函数式 API。

为什么 Flax nnx.Module 不仅仅是 pytrees? 主要原因是,以这种方式很容易意外地丢失对共享引用的跟踪,例如,如果您将两个具有共享 Modulennx.Module 传递到 JAX 边界,您将默默地丢失该共享。Flax 的函数式 API 使这种行为显式,因此更容易推理。

细粒度 State 控制#

经验丰富的 Flax LinenHaiku API 用户可能认识到,将所有状态放在一个结构中并不总是最佳选择,因为在某些情况下您可能希望以不同的方式处理状态的不同子集。当与 JAX 转换 交互时,这是一种常见情况。

例如

  • 并非所有模型状态在与 jax.grad 交互时都可以或应该进行微分。

  • 或者,有时需要在使用 jax.lax.scan 时指定模型状态的哪一部分是承载,哪一部分不是。

为了解决这个问题,Flax NNX API 有 nnx.split,它允许您传递一个或多个 nnx.filterlib.Filters 来将 nnx.Variables 分成互斥的 nnx.States。Flax NNX 使用 Filter 在 API(如 nnx.splitnnx.state() 以及许多 NNX 转换)中创建 State 组。

下面的示例展示了最常见的 Filters

# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)

nnx.display(params, counts)
(正在加载...)
(正在加载...)

注意: nnx.filterlib.Filterss 必须是穷举的,如果未匹配到值,将引发错误。

正如预期的那样,nnx.mergennx.update 方法自然地使用多个 States

# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)