随机性#

与 Haiku/Flax Linen 等系统相比,Flax NNX 中的随机状态得到了彻底简化,它将“随机状态定义为对象状态”。本质上,这意味着随机状态只是另一种类型的状态,存储在变量中,并由模型本身进行管理。Flax NNX 中 RNG 系统的主要特点是其**显式**、**基于顺序**,并使用**动态计数器**。这与Flax Linen 的 RNG 系统略有不同,后者是 (路径 + 顺序) 驱动的,并使用静态计数器。

from flax import nnx
import jax
from jax import random, numpy as jnp

Rngs、RngStream 和 RngState#

Flax NNX 提供了 nnx.Rngs 类型作为管理随机状态的主要便捷 API。继 Flax Linen 的脚步,Rngs 能够创建多个命名 RNG 流,每个流都有自己的状态,用于在 JAX 转换的上下文中对随机性进行严格控制。以下是 Flax NNX 中主要与 RNG 相关的类型的细分

  • Rngs: 主要用户界面。它定义了一组命名 RngStream 对象。

  • nnx.RngStream: 一个可以生成 RNG 密钥流的对象。它在 RngKeyRngCount 变量中分别保存一个根 keycount。当生成新的密钥时,计数器会递增。

  • nnx.RngState: 所有与 RNG 相关的状态的基类型。

    • nnx.RngKey: 用于保存 RNG 密钥的变量类型,它包含一个 tag 属性,其中包含流的名称。

    • nnx.RngCount: 用于保存 RNG 计数的变量类型,它包含一个 tag 属性,其中包含流的名称。

要创建 Rngs 对象,可以将一个整数种子或 jax.random.key 实例简单地传递给构造函数中的任何关键字参数。以下是一个示例

rngs = nnx.Rngs(params=0, dropout=random.key(1))
nnx.display(rngs)

请注意,keycount 变量在 tag 属性中包含流名称。这主要用于过滤,我们将在后面看到。

要生成新的密钥,可以访问其中一个流并使用其 __call__ 方法,不带任何参数。这将使用当前 keycount 通过 random.fold_in 返回一个新的密钥。然后,count 会递增,以便后续调用会返回新的密钥。

params_key = rngs.params()
dropout_key = rngs.dropout()

nnx.display(rngs)

请注意,当生成新的密钥时,key 属性不会改变。

标准流名称#

Flax NNX 的内置层只使用两个标准流名称,如下表所示

名称

描述

params

用于参数初始化

dropout

用于 Dropout 来创建 dropout 掩码

params 用于大多数标准层(LinearConvMultiHeadAttention 等)在构造期间初始化其参数。 dropout 用于 DropoutMultiHeadAttention 生成 dropout 掩码。以下是一个使用 paramsdropout 流的模型示例

class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.drop(self.linear(x)))

model = Model(nnx.Rngs(params=0, dropout=1))

y = model(x=jnp.ones((1, 20)))
print(f'{y.shape = }')
y.shape = (1, 10)

默认流#

使用命名流的缺点之一是用户需要知道创建 Rngs 对象时模型将使用的所有可能名称。虽然可以通过一些文档来解决这个问题,但 Flax NNX 提供了一个 default 流,当找不到流时可以使用它作为后备。要使用默认流,可以将一个整数种子或 jax.random.key 简单地作为第一个位置参数传递。

rngs = nnx.Rngs(0, params=1)

key1 = rngs.params() # call params
key2 = rngs.dropout() # fallback to default
key3 = rngs() # call default directly

# test with Model that uses params and dropout
model = Model(rngs)
y = model(jnp.ones((1, 20)))

nnx.display(rngs)

如上所示,也可以通过调用 Rngs 对象本身来生成来自 default 流的密钥。

注意
对于大型项目,建议使用命名流来避免潜在的冲突。对于小型项目或快速原型设计,只需使用 default 流是一个不错的选择。

过滤随机状态#

随机状态可以使用过滤器进行操作,就像任何其他类型的状态一样。它可以使用类型(RngStateRngKeyRngCount)或使用与流名称相对应的字符串进行过滤(参见过滤器 DSL)。以下是一个使用 nnx.state 和各种过滤器来选择 ModelRngs 的不同子状态的示例

model = Model(nnx.Rngs(params=0, dropout=1))

rng_state = nnx.state(model, nnx.RngState) # all random state
key_state = nnx.state(model, nnx.RngKey) # only keys
count_state = nnx.state(model, nnx.RngCount) # only counts
rng_params_state = nnx.state(model, 'params') # only params
rng_dropout_state = nnx.state(model, 'dropout') # only dropout
params_key_state = nnx.state(model, nnx.All('params', nnx.RngKey)) # params keys

nnx.display(params_key_state)

重新播种#

在 Haiku 和 Flax Linen 中,随机状态是在每次调用模型之前明确地传递给 apply。这使得在需要时(例如为了可重复性)轻松控制模型的随机性。在 Flax NNX 中,有两种方法可以解决这个问题

  1. 通过 __call__ 堆栈手动传递 Rngs 对象。标准层(如 DropoutMultiHeadAttention)接受一个 rngs 参数,以防您希望对随机状态进行严格控制。

  2. 使用 nnx.reseed 将模型的随机状态设置为特定配置。此选项侵入性较小,即使模型不是为启用手动控制随机状态而设计的,也可以使用。

reseed 是一个函数,它接受一个任意图节点(包括 Flax NNX 模块的 pytree)以及一些关键字参数,其中包含 RngStream 的新种子或密钥值,这些 RngStream 由参数名称指定。 reseed 随后会遍历图并更新匹配的 RngStream 的随机状态,这包括将 key 设置为可能的新值并将 count 重置为零。

以下是如何使用 reseed 重置 Dropout 层的随机状态并验证计算结果与第一次调用模型时是否相同的示例

model = Model(nnx.Rngs(params=0, dropout=1))
x = jnp.ones((1, 20))

y1 = model(x)
y2 = model(x)

nnx.reseed(model, dropout=1) # reset dropout RngState
y3 = model(x)

assert not jnp.allclose(y1, y2) # different
assert jnp.allclose(y1, y3)     # same

分割 Rngs#

在与 vmappmap 等转换交互时,通常需要分割随机状态,以便每个副本都有自己的唯一状态。这可以通过两种方式完成,要么是在将密钥传递给 Rngs 流之一之前手动分割密钥,要么使用 nnx.split_rngs 装饰器,该装饰器会自动分割函数输入中找到的任何 RngStream 的随机状态,并在函数调用结束后自动“降低”它们。 split_rngs 更方便,因为它可以与转换完美配合,因此我们将在此处显示一个示例

rngs = nnx.Rngs(params=0, dropout=1)

@nnx.split_rngs(splits=5, only='dropout')
def f(rngs: nnx.Rngs):
  print('Inside:')
  # rngs.dropout() # ValueError: fold_in accepts a single key...
  nnx.display(rngs)

f(rngs)

print('Outside:')
rngs.dropout() # works!
nnx.display(rngs)
Inside:
Outside:

请注意,split_rngs 允许将过滤器传递给 only 关键字参数以选择在函数内部应该分割的 RngStream,在本例中,我们只分割了 dropout 流。

转换#

如前所述,在 Flax NNX 中,随机状态只是另一种类型的状态,这意味着它在转换方面没有任何特殊之处。这意味着您应该能够使用每个转换的状态处理 API 来获得想要的结果。在本节中,我们将通过两个使用转换中的随机状态的示例,一个使用 pmap,我们将看到如何分割 RNG 状态,另一个使用 scan,我们将看到如何冻结 RNG 状态。

数据并行 dropout#

在第一个示例中,我们将探索如何使用 pmap 在数据并行环境中调用我们的 Model。由于 Model 使用了 Dropout,我们需要拆分 dropout 的随机状态,以确保每个副本获得不同的 dropout 掩码。 StateAxes 被传递给 in_axes,以指定 modeldropout 流将在轴 0 上并行化,而其其余状态将被复制。 split_rngs 用于将 dropout 流的键拆分为 N 个唯一的键,每个副本一个。

model = Model(nnx.Rngs(params=0, dropout=1))

num_devices = jax.local_device_count()
x = jnp.ones((num_devices, 16, 20))
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

@nnx.split_rngs(splits=num_devices, only='dropout')
@nnx.pmap(in_axes=(state_axes, 0), out_axes=0)
def forward(model: Model, x: jnp.ndarray):
  return model(x)

y = forward(model, x)
print(y.shape)
(1, 16, 10)

循环 dropout#

接下来我们将探索如何实现一个使用循环 dropout 的 RNNCell。为此,我们将简单地创建一个 Dropout 层,该层将从自定义的 recurrent_dropout 流中采样键,并将对 RNNCell 的隐藏状态 h 应用 dropout。将定义一个 initial_state 方法来创建 RNNCell 的初始状态。

class Count(nnx.Variable): pass

class RNNCell(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
    self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
    self.dout = dout
    self.count = Count(jnp.array(0, jnp.uint32))

  def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:
    h = self.drop(h) # recurrent dropout
    y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))
    self.count += 1
    return y, y

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.dout))

cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))

接下来,我们将使用 scanunroll 函数上实现 rnn_forward 操作。循环 dropout 的关键要素是在所有时间步长上应用相同的 dropout 掩码,为此,我们将 StateAxes 传递给 scanin_axes,指定 cellrecurrent_dropout 流将被广播,并且 cell 的其余状态将被保留。此外,隐藏状态 h 将是 scanCarry 变量,序列 x 将在其轴 1 上被扫描。

@nnx.jit
def rnn_forward(cell: RNNCell, x: jax.Array):
  h = cell.initial_state(batch_size=x.shape[0])

  # broadcast 'recurrent_dropout' RNG state to have the same mask on every step
  state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})
  @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))
  def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:
    h, y = cell(h, x)
    return h, y

  h, y = unroll(cell, h, x)
  return y

x = jnp.ones((4, 20, 8))
y = rnn_forward(cell, x)

print(f'{y.shape = }')
print(f'{cell.count.value = }')
y.shape = (4, 20, 16)
cell.count.value = Array(20, dtype=uint32)