随机性#
与 Haiku/Flax Linen 等系统相比,Flax NNX 中的随机状态得到了彻底简化,它将“随机状态定义为对象状态”。本质上,这意味着随机状态只是另一种类型的状态,存储在变量中,并由模型本身进行管理。Flax NNX 中 RNG 系统的主要特点是其**显式**、**基于顺序**,并使用**动态计数器**。这与Flax Linen 的 RNG 系统略有不同,后者是 (路径 + 顺序) 驱动的,并使用静态计数器。
from flax import nnx
import jax
from jax import random, numpy as jnp
Rngs、RngStream 和 RngState#
Flax NNX 提供了 nnx.Rngs
类型作为管理随机状态的主要便捷 API。继 Flax Linen 的脚步,Rngs
能够创建多个命名 RNG 流,每个流都有自己的状态,用于在 JAX 转换的上下文中对随机性进行严格控制。以下是 Flax NNX 中主要与 RNG 相关的类型的细分
Rngs: 主要用户界面。它定义了一组命名
RngStream
对象。nnx.RngStream: 一个可以生成 RNG 密钥流的对象。它在
RngKey
和RngCount
变量中分别保存一个根key
和count
。当生成新的密钥时,计数器会递增。nnx.RngState: 所有与 RNG 相关的状态的基类型。
nnx.RngKey: 用于保存 RNG 密钥的变量类型,它包含一个
tag
属性,其中包含流的名称。nnx.RngCount: 用于保存 RNG 计数的变量类型,它包含一个
tag
属性,其中包含流的名称。
要创建 Rngs
对象,可以将一个整数种子或 jax.random.key
实例简单地传递给构造函数中的任何关键字参数。以下是一个示例
rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)
请注意,key
和 count
变量在 tag
属性中包含流名称。这主要用于过滤,我们将在后面看到。
要生成新的密钥,可以访问其中一个流并使用其 __call__
方法,不带任何参数。这将使用当前 key
和 count
通过 random.fold_in
返回一个新的密钥。然后,count
会递增,以便后续调用会返回新的密钥。
params_key = rngs.params()
dropout_key = rngs.dropout()
nnx.display(rngs)
请注意,当生成新的密钥时,key
属性不会改变。
标准流名称#
Flax NNX 的内置层只使用两个标准流名称,如下表所示
名称 |
描述 |
---|---|
|
用于参数初始化 |
|
用于 |
params
用于大多数标准层(Linear
、Conv
、MultiHeadAttention
等)在构造期间初始化其参数。 dropout
用于 Dropout
和 MultiHeadAttention
生成 dropout 掩码。以下是一个使用 params
和 dropout
流的模型示例
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.drop(self.linear(x)))
model = Model(nnx.Rngs(params=0, dropout=1))
y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)
默认流#
使用命名流的缺点之一是用户需要知道创建 Rngs
对象时模型将使用的所有可能名称。虽然可以通过一些文档来解决这个问题,但 Flax NNX 提供了一个 default
流,当找不到流时可以使用它作为后备。要使用默认流,可以将一个整数种子或 jax.random.key
简单地作为第一个位置参数传递。
rngs = nnx.Rngs(0, params=1)
key1 = rngs.params() # call params
key2 = rngs.dropout() # fallback to default
key3 = rngs() # call default directly
# test with Model that uses params and dropout
model = Model(rngs)
y = model(jnp.ones((1, 20)))
nnx.display(rngs)
如上所示,也可以通过调用 Rngs
对象本身来生成来自 default
流的密钥。
注意
对于大型项目,建议使用命名流来避免潜在的冲突。对于小型项目或快速原型设计,只需使用default
流是一个不错的选择。
过滤随机状态#
随机状态可以使用过滤器进行操作,就像任何其他类型的状态一样。它可以使用类型(RngState
、RngKey
、RngCount
)或使用与流名称相对应的字符串进行过滤(参见过滤器 DSL)。以下是一个使用 nnx.state
和各种过滤器来选择 Model
中 Rngs
的不同子状态的示例
model = Model(nnx.Rngs(params=0, dropout=1))
rng_state = nnx.state(model, nnx.RngState) # all random state
key_state = nnx.state(model, nnx.RngKey) # only keys
count_state = nnx.state(model, nnx.RngCount) # only counts
rng_params_state = nnx.state(model, 'params') # only params
rng_dropout_state = nnx.state(model, 'dropout') # only dropout
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys
nnx.display(params_key_state)
重新播种#
在 Haiku 和 Flax Linen 中,随机状态是在每次调用模型之前明确地传递给 apply
。这使得在需要时(例如为了可重复性)轻松控制模型的随机性。在 Flax NNX 中,有两种方法可以解决这个问题
通过
__call__
堆栈手动传递Rngs
对象。标准层(如Dropout
和MultiHeadAttention
)接受一个rngs
参数,以防您希望对随机状态进行严格控制。使用
nnx.reseed
将模型的随机状态设置为特定配置。此选项侵入性较小,即使模型不是为启用手动控制随机状态而设计的,也可以使用。
reseed
是一个函数,它接受一个任意图节点(包括 Flax NNX 模块的 pytree)以及一些关键字参数,其中包含 RngStream
的新种子或密钥值,这些 RngStream
由参数名称指定。 reseed
随后会遍历图并更新匹配的 RngStream
的随机状态,这包括将 key
设置为可能的新值并将 count
重置为零。
以下是如何使用 reseed
重置 Dropout
层的随机状态并验证计算结果与第一次调用模型时是否相同的示例
model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))
y1 = model(x)
y2 = model(x)
nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)
assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3) # same
分割 Rngs#
在与 vmap
或 pmap
等转换交互时,通常需要分割随机状态,以便每个副本都有自己的唯一状态。这可以通过两种方式完成,要么是在将密钥传递给 Rngs
流之一之前手动分割密钥,要么使用 nnx.split_rngs
装饰器,该装饰器会自动分割函数输入中找到的任何 RngStream
的随机状态,并在函数调用结束后自动“降低”它们。 split_rngs
更方便,因为它可以与转换完美配合,因此我们将在此处显示一个示例
rngs = nnx.Rngs(params=0, dropout=1)
@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
print('Inside:')
# rngs.dropout() # ValueError: fold_in accepts a single key...
nnx.display(rngs)
f(rngs)
print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:
请注意,split_rngs
允许将过滤器传递给 only
关键字参数以选择在函数内部应该分割的 RngStream
,在本例中,我们只分割了 dropout
流。
转换#
如前所述,在 Flax NNX 中,随机状态只是另一种类型的状态,这意味着它在转换方面没有任何特殊之处。这意味着您应该能够使用每个转换的状态处理 API 来获得想要的结果。在本节中,我们将通过两个使用转换中的随机状态的示例,一个使用 pmap
,我们将看到如何分割 RNG 状态,另一个使用 scan
,我们将看到如何冻结 RNG 状态。
数据并行 dropout#
在第一个示例中,我们将探索如何使用 pmap
在数据并行环境中调用我们的 Model
。由于 Model 使用了 Dropout
,我们需要拆分 dropout
的随机状态,以确保每个副本获得不同的 dropout 掩码。 StateAxes
被传递给 in_axes
,以指定 model
的 dropout
流将在轴 0
上并行化,而其其余状态将被复制。 split_rngs
用于将 dropout
流的键拆分为 N 个唯一的键,每个副本一个。
model = Model(nnx.Rngs(params=0, dropout=1))
num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})
@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
return model(x)
y = forward(model, x)
print(y.shape)
(1, 16, 10)
循环 dropout#
接下来我们将探索如何实现一个使用循环 dropout 的 RNNCell。为此,我们将简单地创建一个 Dropout
层,该层将从自定义的 recurrent_dropout
流中采样键,并将对 RNNCell 的隐藏状态 h
应用 dropout。将定义一个 initial_state
方法来创建 RNNCell 的初始状态。
class Count(nnx.Variable): pass
class RNNCell(nnx.Module):
def __init__(self, din, dout, rngs):
self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
self.dout = dout
self.count = Count(jnp.array(0, jnp.uint32))
def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
h = self.drop(h) # recurrent dropout
y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
self.count += 1
return y, y
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.dout))
cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))
接下来,我们将使用 scan
在 unroll
函数上实现 rnn_forward
操作。循环 dropout 的关键要素是在所有时间步长上应用相同的 dropout 掩码,为此,我们将 StateAxes
传递给 scan
的 in_axes
,指定 cell
的 recurrent_dropout
流将被广播,并且 cell
的其余状态将被保留。此外,隐藏状态 h
将是 scan
的 Carry
变量,序列 x
将在其轴 1
上被扫描。
@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
h = cell.initial_state(batch_size=x.shape[0])
# broadcast 'recurrent_dropout' RNG state to have the same mask on every step
state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
@nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
h, y = cell(h, x)
return h, y
h, y = unroll(cell, h, x)
return y
x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)
print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)