变换#
- flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
可处理模块/图节点作为参数的
jax.grad
的提升版本。每个图节点的可微状态由 wrt 过滤器定义,默认情况下设置为 nnx.Param。在内部,会提取图节点的
State
,根据 wrt 过滤器进行过滤,然后传递给底层的jax.grad
函数。图节点的梯度类型为State
。示例
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) })
- 参数
fun – 要进行微分的函数。它在由
argnums
指定的位置上的参数应为数组、标量、图节点或标准 Python 容器。由argnums
指定的位置上的参数数组必须为非精确(即浮点或复数)类型。它应返回一个标量(包括形状为()
的数组,但不包括形状为(1,)
等的数组)。argnums – 可选,整数或整数序列。指定要针对其进行微分的哪个位置参数(默认值为 0)。
has_aux – 可选,布尔值。指示
fun
是否返回一对,其中第一个元素被视为要进行微分的数学函数的输出,而第二个元素是辅助数据。默认值为 False。holomorphic – 可选,布尔值。指示
fun
是否承诺为全纯函数。如果为 True,则输入和输出必须为复数。默认值为 False。allow_int – 可选,布尔值。是否允许针对整数值输入进行微分。整数输入的梯度将具有一个简单的向量空间类型 (float0)。默认值为 False。
reduce_axes – 可选,轴名称元组。如果某个轴在此处列出,并且
fun
隐式地在该轴上广播一个值,则反向传递将对相应的梯度执行psum
。否则,梯度将针对命名轴上的每个示例。例如,如果'batch'
是一个命名的批处理轴,则grad(f, reduce_axes=('batch',))
将创建一个计算总梯度的函数,而grad(f)
将创建一个计算每个示例梯度的函数。
- flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[source]#
可处理模块/图节点作为参数的
jax.jit
的提升版本。- 参数
fun –
要进行 jit 处理的函数。
fun
应该是一个纯函数,因为副作用可能只执行一次。fun
的参数和返回值应为数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)。由static_argnums
指定的位置参数可以是任何内容,只要它们是可散列的并且定义了相等操作。静态参数作为编译缓存键的一部分包含在内,因此必须定义散列和相等操作符。JAX 保持对
fun
的弱引用,以用作编译缓存键,因此对象fun
必须是弱可引用的。大多数Callable
对象将已经满足此要求。in_shardings –
与
fun
参数结构匹配的 Pytree,所有实际参数都被资源分配规范替换。也可以指定 Pytree 前缀(例如,整个子树中的一个值),在这种情况下,叶节点将广播到该子树中的所有值。in_shardings
参数是可选的。JAX 将从输入jax.Array
推断分片,如果无法推断,则默认情况下将复制输入。- 有效的资源分配规范是
Sharding
,它将决定如何对值进行分区。将被分区。使用此方法,无需使用网格上下文管理器。
None
,将赋予 JAX 选择任何它想要的分片的自由。对于 in_shardings,JAX 将将其标记为已复制,但此行为将来可能会改变。对于 out_shardings,我们将依靠 XLA GSPMD 分区器来确定输出分片。
每个维度的尺寸必须是分配给它的资源总数的倍数。这类似于 pjit 的 in_shardings。
out_shardings –
与
in_shardings
相似,但指定函数输出的资源分配。这类似于 pjit 的 out_shardings。out_shardings
参数是可选的。如果未指定,则jax.jit()
将使用 GSPMD 的分片传播来确定输出的分片应是什么。static_argnums –
一个可选的整数或整数集合,用于指定哪些位置参数应被视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。
静态参数应该是可哈希的,这意味着
__hash__
和__eq__
都已实现,并且是不可变的。 使用这些常量的不同值调用 jitted 函数将触发重新编译。 必须将不是数组或其容器的参数标记为静态。如果既没有提供
static_argnums
也没有提供static_argnames
,则没有参数被视为静态。 如果未提供static_argnums
但提供了static_argnames
,反之亦然,JAX 会使用inspect.signature(fun)
来查找任何与static_argnames
对应的定位参数(反之亦然)。 如果同时提供static_argnums
和static_argnames
,则不会使用inspect.signature
,并且只有在static_argnums
或static_argnames
中列出的实际参数将被视为静态。static_argnames – 一个可选的字符串或字符串集合,指定要将哪些命名参数视为静态(编译时常量)。 有关详细信息,请参阅
static_argnums
的注释。 如果未提供,但设置了static_argnums
,则默认值基于调用inspect.signature(fun)
来查找对应的命名参数。donate_argnums –
指定哪些定位参数缓冲区“捐赠”给计算。 如果计算完成后不再需要参数缓冲区,则捐赠它们是安全的。 在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如将您的一个输入缓冲区循环利用来存储结果。 您不应该重用捐赠给计算的缓冲区,如果您尝试重用,JAX 会引发错误。 默认情况下,不会捐赠任何参数缓冲区。
如果既没有提供
donate_argnums
也没有提供donate_argnames
,则没有参数被捐赠。 如果未提供donate_argnums
但提供了donate_argnames
,反之亦然,JAX 会使用inspect.signature(fun)
来查找任何与donate_argnames
对应的定位参数(反之亦然)。 如果同时提供donate_argnums
和donate_argnames
,则不会使用inspect.signature
,并且只有在donate_argnums
或donate_argnames
中列出的实际参数将被捐赠。有关缓冲区捐赠的更多详细信息,请参阅 FAQ。
donate_argnames – 一个可选的字符串或字符串集合,指定哪些命名参数被捐赠给计算。 有关详细信息,请参阅
donate_argnums
的注释。 如果未提供,但设置了donate_argnums
,则默认值基于调用inspect.signature(fun)
来查找对应的命名参数。keep_unused – 如果为 False(默认值),JAX 确定为 fun 未使用的参数 *可能* 会从生成的已编译 XLA 可执行文件中删除。 此类参数不会被传输到设备,也不会被提供给底层可执行文件。 如果为 True,则不会修剪未使用的参数。
device – 这是一个实验性功能,API 可能会发生变化。 可选的,jitted 函数将运行的设备。(可以通过
jax.devices()
获取可用设备。) 默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用jax.devices()[0]
。backend – 这是一个实验性功能,API 可能会发生变化。 可选的,一个字符串,表示 XLA 后端:
'cpu'
、'gpu'
或'tpu'
。inline – 指定此函数是否应该内联到封闭的 jaxprs 中(而不是表示为具有其自身子 jaxpr 的 xla_call 原语的应用)。 默认值为 False。
- 返回值
一个包装后的
fun
版本,设置为进行即时编译。
- flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[source]#
- flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[source]#
- flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
- flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[source]#
jax.vmap 的引用感知版本。
- 参数
f – 要在其他轴上映射的函数。
in_axes – 一个整数、None 或值序列,指定要映射哪些输入数组轴(见 jax.vmap)。 除了整数和 None 之外,还可以使用
StateAxes
来控制图形节点(如模块)如何通过指定要应用于给定 Filter 的图形节点子状态的轴来进行向量化。out_axes – 一个整数、None 或 pytree,指示映射的轴应出现在输出中的位置(见 jax.vmap)。
axis_name – 可选的,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合。
axis_size – 可选的,一个整数,指示要映射的轴的大小。 如果未提供,则映射的轴大小将从参数中推断。
- 返回值
具有与
f
对应的参数的f
的批处理/向量化版本,但在由in_axes
指示的位置具有额外的数组轴,以及与f
对应的返回值,但在由out_axes
指示的位置具有额外的数组轴。
示例
>>> from flax import nnx >>> from jax import random, numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((5, 2)) ... >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) ... def forward(model, x): ... return model(x) ... >>> y = forward(model, x) >>> y.shape (5, 3)
>>> class LinearEnsemble(nnx.Module): ... def __init__(self, num, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) ... >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) >>> x = jnp.ones((2,)) ... >>> @nnx.vmap(in_axes=(0, None), out_axes=0) ... def forward(model, x): ... return jnp.dot(x, model.w.value) ... >>> y = forward(model, x) >>> y.shape (5, 3)
为了控制图形节点子状态如何进行向量化,可以将
StateAxes
传递给in_axes
和out_axes
,指定要应用于给定过滤器的每个子状态的轴。 以下示例显示了如何在保持不同批处理统计信息和 dropout 随机状态的同时,在集成成员之间共享参数>>> class Foo(nnx.Module): ... def __init__(self): ... self.a = nnx.Param(jnp.arange(4)) ... self.b = nnx.BatchStat(jnp.arange(4)) ... >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) ... def mul(foo): ... return foo.a * foo.b ... >>> foo = Foo() >>> y = mul(foo) >>> y Array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 2, 4, 6], [0, 3, 6, 9]], dtype=int32)