亚麻基础#
Flax NNX 是一个新的简化 API,旨在简化在 JAX 中创建、检查、调试和分析神经网络。它通过添加对 Python 引用语义的一流支持来实现这一点。这允许用户使用常规 Python 对象来表达他们的模型,这些对象被建模为 PyGraphs(而不是 pytrees),从而实现引用共享和可变性。这种 API 设计应该让 PyTorch 或 Keras 用户感到宾至如归。
在本指南中,您将了解
Flax
nnx.Module
系统:创建和初始化自定义Linear
层的示例。有状态计算:创建 Flax
nnx.Variable
并更新其值(例如,正向传递期间所需的 state 更新)的示例。嵌套的
nnx.Module
:包含Linear
、nnx.Dropout
和nnx.BatchNorm
层的 MLP 示例。模型手术:将模型内的自定义
Linear
层替换为自定义LoraLinear
层的示例。
Flax 转换:使用
nnx.jit
进行自动状态管理的示例。nnx.scan
在层上。
Flax NNX 函数式 API:具有
nnx.Param
的自定义StatefulLinear
层的示例,可以对 state 进行细粒度控制。
设置#
使用 pip
安装 Flax 并赋予必要的依赖项
# ! pip install -U flax treescope
from flax import nnx
import jax
import jax.numpy as jnp
Flax nnx.Module
系统#
Flax nnx.Module
与 Flax Linen 或 Haiku 中的其他 Module
系统之间的主要区别在于,在 NNX 中,一切都 明确。这意味着,除其他外,
nnx.Module
本身直接保存 state(例如参数)。用户通过线程处理 PRNG state。
必须在初始化时提供所有形状信息(无形状推断)。
让我们从创建一个 Linear
nnx.Module
开始。以下代码显示了
动态 state 通常存储在
nnx.Param
中,而静态 state(NNX 未处理的所有类型),例如整数或字符串,则直接存储。类型为
jax.Array
和numpy.ndarray
的属性也被视为动态 state,尽管将它们存储在nnx.Variable
中(例如Param
)更可取。nnx.Rngs
对象可用于根据传递给构造函数的根 PRNG 密钥获取新的唯一密钥。
class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs.params()
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.din, self.dout = din, dout
def __call__(self, x: jax.Array):
return x @ self.w + self.b
另请注意
可以使用
value
属性访问nnx.Variable
的内部值,但为了方便起见,它们实现了所有数值运算符,并且可以直接在算术表达式中使用(如上面的代码所示)。
要初始化 Flax nnx.Module
,只需调用构造函数,Module
的所有参数通常都会急切创建。由于 nnx.Module
保存自己的 state 方法,因此您可以直接调用它们,而无需单独的 apply
方法。这对于调试非常方便,允许您直接检查模型的整个结构。
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))
print(y)
nnx.display(model)
[[1.245453 0.74195766 0.8553282 0.6763327 1.2617068 ]]
上面的 nnx.display
可视化是使用出色的 Treescope 库生成的。
有状态计算#
实现层(例如 nnx.BatchNorm
)需要在正向传递期间执行 state 更新。在 Flax NNX 中,您只需要创建一个 nnx.Variable
,并在正向传递期间更新其 .value
。
class Count(nnx.Variable): pass
class Counter(nnx.Module):
def __init__(self):
self.count = Count(jnp.array(0))
def __call__(self):
self.count += 1
counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)
在 JAX 中通常避免使用可变引用。但 Flax NNX 提供了健全的机制来处理它们,如本指南后面部分所示。
嵌套的 nnx.Module
#
Flax nnx.Module
可用于以嵌套结构组合其他 Module
。这些可以作为属性直接分配,或者在任何(嵌套的)pytree 类型的属性内分配,例如 list
、dict
、tuple
等。
以下示例展示了如何通过子类化 nnx.Module
来定义一个简单的 MLP
。该模型包含两个 Linear
层、一个 nnx.Dropout
层和一个 nnx.BatchNorm
层
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.linear2 = Linear(dmid, dout, rngs=rngs)
def __call__(self, x: jax.Array):
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
return self.linear2(x)
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
在 Flax 中,nnx.Dropout
是一个有状态的模块,它存储一个 nnx.Rngs
对象,以便它可以在前向传递过程中生成新的掩码,而无需用户每次都传递新的密钥。
模型手术#
Flax nnx.Module
默认情况下是可变的。这意味着它们的结构可以在任何时候更改,这使得模型手术非常容易,因为任何子 Module
属性都可以被其他任何东西替换,例如新的 Module
、现有的共享 Module
、不同类型的 Module
等等。此外,nnx.Variable
也可以被修改或替换/共享。
以下示例展示了如何将之前示例中 MLP
模型中的 Linear
层替换为 LoraLinear
层
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
self.linear = linear
self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)
# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
Flax 变换#
Flax NNX 变换 (transforms) 扩展了 JAX 变换 以支持 nnx.Module
和其他对象。它们作为其等效 JAX 对应物的超集,增加了对对象状态的感知,并提供了用于变换它的额外 API。
Flax 变换的主要特性之一是保留了引用语义,这意味着只要在变换规则内合法,在变换内部发生的任何对象图变异都会传播到外部。在实践中,这意味着 Flax 程序可以使用命令式代码来表达,从而极大地简化了用户体验。
在以下示例中,您定义了一个 train_step
函数,它接受一个 MLP
模型、一个 nnx.Optimizer
和一批数据,并返回该步的损失。损失和梯度是使用 nnx.value_and_grad
变换在 loss_fn
上计算的。梯度被传递给优化器的 nnx.Optimizer.update
方法,以更新 model
的参数。
import optax
# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
@nnx.jit # Automatic state management
def train_step(model, optimizer, x, y):
def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # In place updates.
return loss
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)
print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000279, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)
这个示例中有两件事值得一提
对每个
nnx.BatchNorm
和nnx.Dropout
层状态的更新会自动从loss_fn
内部传播到train_step
,一直传播到外部的model
引用。optimizer
持有一个对model
的可变引用 - 这种关系在train_step
函数内部被保留,这使得可以使用优化器单独更新模型的参数成为可能。
nnx.scan
在层上#
下一个示例使用 Flax nnx.vmap
创建多个 MLP 层的堆栈,并使用 nnx.scan
迭代地将堆栈中的每一层应用于输入。
在下面的代码中注意以下几点
自定义的
create_model
函数接受一个键并返回一个MLP
对象,因为您创建了五个键并使用nnx.vmap
在create_model
上,创建了一个由 5 个MLP
对象组成的堆栈。The
nnx.scan
用于迭代地将堆栈中的每个MLP
应用于输入x
。The
nnx.scan
(有意地) 偏离了jax.lax.scan
,而是模仿了nnx.vmap
,后者更具表现力。nnx.scan
允许指定多个输入、每个输入/输出的扫描轴以及进位的位置。The
nnx.BatchNorm
和nnx.Dropout
层的状态更新会自动由nnx.scan
传播。
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
x = model(x)
return x
x = jnp.ones((3, 10))
y = forward(model, x)
print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)
Flax NNX 变换如何实现这一点?为了理解 Flax NNX 对象如何与 JAX 变换交互,下一节解释了 Flax NNX 函数式 API。
Flax 函数式 API#
Flax NNX 函数式 API 在引用/对象语义和值/pytree 语义之间建立了清晰的界限。它还允许对 Flax Linen 和 Haiku 用户所习惯的状态进行相同程度的细粒度控制。Flax NNX 函数式 API 包含三个基本方法:nnx.split
、nnx.merge
和 nnx.update
.
下面是 StatefulLinear
nnx.Module
的示例,它使用了函数式 API。它包含
一些
nnx.Param
nnx.Variable
;以及一个自定义的
Count()
nnx.Variable
类型,它用于跟踪每次前向传递时都会增加的整数标量状态。
class Count(nnx.Variable): pass
class StatefulLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(jnp.array(0, dtype=jnp.uint32))
def __call__(self, x: jax.Array):
self.count += 1
return x @ self.w + self.b
model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))
nnx.display(model)
State
和 GraphDef
#
Flax nnx.Module
可以使用 nnx.split
函数分解为 nnx.State
和 nnx.GraphDef
nnx.State
是一个从字符串到nnx.Variable
或嵌套的State
的Mapping
。nnx.GraphDef
包含重建nnx.Module
图所需的所有静态信息,它类似于 JAX 的PyTreeDef
.
graphdef, state = nnx.split(model)
nnx.display(graphdef, state)
split
、merge
和 update
#
Flax 的 nnx.merge
是 nnx.split
的反向操作。它接受 nnx.GraphDef
+ nnx.State
并重建 nnx.Module
。以下示例展示了这一点
nnx.update
可以使用给定的nnx.State
的内容原地更新对象。这种模式用于将状态从转换传播回外部的源对象。
print(f'{model.count.value = }')
# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)
@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
# 2. Use `nnx.merge` to create a new model inside the JAX transformation.
model = nnx.merge(graphdef, state)
# 3. Call the `nnx.Module`
y = model(x)
# 4. Use `nnx.split` to propagate `nnx.State` updates.
_, state = nnx.split(model)
return y, state
y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)
print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)
这种模式的关键在于,在转换上下文中(包括基本的急切解释器)使用可变引用是可以的,但跨越边界时必须使用函数式 API。
为什么 Flax nnx.Module
不仅仅是 pytrees? 主要原因是,以这种方式很容易意外地丢失对共享引用的跟踪,例如,如果您将两个具有共享 Module
的 nnx.Module
传递到 JAX 边界,您将默默地丢失该共享。Flax 的函数式 API 使这种行为显式,因此更容易推理。
细粒度 State
控制#
经验丰富的 Flax Linen 或 Haiku API 用户可能认识到,将所有状态放在一个结构中并不总是最佳选择,因为在某些情况下您可能希望以不同的方式处理状态的不同子集。当与 JAX 转换 交互时,这是一种常见情况。
例如
并非所有模型状态在与
jax.grad
交互时都可以或应该进行微分。或者,有时需要在使用
jax.lax.scan
时指定模型状态的哪一部分是承载,哪一部分不是。
为了解决这个问题,Flax NNX API 有 nnx.split
,它允许您传递一个或多个 nnx.filterlib.Filter
s 来将 nnx.Variable
s 分成互斥的 nnx.State
s。Flax NNX 使用 Filter
在 API(如 nnx.split
、nnx.state()
以及许多 NNX 转换)中创建 State
组。
下面的示例展示了最常见的 Filter
s
# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
nnx.display(params, counts)
注意: nnx.filterlib.Filter
ss 必须是穷举的,如果未匹配到值,将引发错误。
正如预期的那样,nnx.merge
和 nnx.update
方法自然地使用多个 State
s
# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)