词汇表

词汇表#

有关其他术语,请参考 JAX 词汇表

过滤器#

一种从 模块 中提取特定 变量 的方法。通常通过在模块上调用 nnx.split 来实现。查看 过滤器指南 了解更多信息。

折叠#

给定输入 PRNG 密钥和整数生成新的 PRNG 密钥。通常在您想要生成新的密钥但仍能使用原始 rng 密钥的情况下使用。您也可以使用 jax.random.split 来实现这一点,但这实际上会创建两个 RNG 密钥,速度较慢。请参阅我们的 RNG 指南,了解 Flax 如何自动生成新的 PRNG 密钥。

GraphDef#

nnx.GraphDef,一个表示 nnx.Module 定义的所有静态、无状态、Pythonic 部分的类。

提升转换#

JAX 转换 的包装版本,允许转换后的函数接收 Flax 模块 作为输入或输出。例如,jax.jit 的提升版本将是 flax.nnx.jit。查看 提升转换指南

合并#

查看 拆分和合并

模块#

nnx.Module,一个数据类,允许以引用透明的方式定义和初始化参数。它负责存储和更新自身中的变量和参数。

参数#

nnx.Paramnnx.Variable 的一个特定子类,通常包含可训练的权重。

RNG 状态#

Flax 模块 可以保留对 RNG 状态 对象 的引用,该对象可以生成新的 JAX PRNG 密钥。这些密钥用于通过 JAX 的函数式随机数生成器 生成随机 JAX 数组。您可以使用具有不同种子的 RNG 状态来对模型进行更细粒度的控制(例如,对参数和 dropout 掩码使用独立的随机数)。查看 RNG 指南 了解更多信息。

拆分和合并#

nnx.split,一种通过两部分来表示 nnx.Module 的方法 - 一个捕获其 Pythonic、静态信息的静态 GraphDef,以及一个或多个捕获其 JAX 数组的 变量状态(以 pytrees 的形式)。它们可以通过 nnx.merge 合并回原始模块。

变量#

Flax 模块 中的 权重 / 参数 / 数据 / 数组。变量在模块中定义为 nnx.Variable 或其子类。

变量状态#

nnx.VariableState模块 中所有 变量 的纯函数式 pytree。由于它是纯函数式,因此它可以作为 JAX 转换函数的输入或输出。可以通过 拆分 模块来获取。