词汇表#
有关其他术语,请参考 JAX 词汇表。
- 过滤器#
一种从 模块 中提取特定 变量 的方法。通常通过在模块上调用
nnx.split
来实现。查看 过滤器指南 了解更多信息。- 折叠#
给定输入 PRNG 密钥和整数生成新的 PRNG 密钥。通常在您想要生成新的密钥但仍能使用原始 rng 密钥的情况下使用。您也可以使用 jax.random.split 来实现这一点,但这实际上会创建两个 RNG 密钥,速度较慢。请参阅我们的 RNG 指南,了解 Flax 如何自动生成新的 PRNG 密钥。
- GraphDef#
nnx.GraphDef
,一个表示nnx.Module
定义的所有静态、无状态、Pythonic 部分的类。- 提升转换#
对 JAX 转换 的包装版本,允许转换后的函数接收 Flax 模块 作为输入或输出。例如,jax.jit 的提升版本将是
flax.nnx.jit
。查看 提升转换指南。- 合并#
查看 拆分和合并。
- 模块#
nnx.Module
,一个数据类,允许以引用透明的方式定义和初始化参数。它负责存储和更新自身中的变量和参数。- 参数#
nnx.Param
,nnx.Variable
的一个特定子类,通常包含可训练的权重。- RNG 状态#
Flax
模块
可以保留对RNG 状态 对象
的引用,该对象可以生成新的 JAX PRNG 密钥。这些密钥用于通过 JAX 的函数式随机数生成器 生成随机 JAX 数组。您可以使用具有不同种子的 RNG 状态来对模型进行更细粒度的控制(例如,对参数和 dropout 掩码使用独立的随机数)。查看 RNG 指南 了解更多信息。- 拆分和合并#
nnx.split
,一种通过两部分来表示 nnx.Module 的方法 - 一个捕获其 Pythonic、静态信息的静态 GraphDef,以及一个或多个捕获其 JAX 数组的 变量状态(以 pytrees 的形式)。它们可以通过nnx.merge
合并回原始模块。- 变量#
Flax 模块 中的 权重 / 参数 / 数据 / 数组。变量在模块中定义为
nnx.Variable
或其子类。- 变量状态#
nnx.VariableState
,模块 中所有 变量 的纯函数式 pytree。由于它是纯函数式,因此它可以作为 JAX 转换函数的输入或输出。可以通过 拆分 模块来获取。