变换#

class flax.nnx.Jit(*args, **kwargs)[source]#
class flax.nnx.Remat(*args, **kwargs)[source]#
class flax.nnx.Scan(*args, **kwargs)[source]#
class flax.nnx.Vmap(*args, **kwargs)[source]#
flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#

可处理模块/图节点作为参数的 jax.grad 的提升版本。

每个图节点的可微状态由 wrt 过滤器定义,默认情况下设置为 nnx.Param。在内部,会提取图节点的 State,根据 wrt 过滤器进行过滤,然后传递给底层的 jax.grad 函数。图节点的梯度类型为 State

示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
参数
  • fun – 要进行微分的函数。它在由 argnums 指定的位置上的参数应为数组、标量、图节点或标准 Python 容器。由 argnums 指定的位置上的参数数组必须为非精确(即浮点或复数)类型。它应返回一个标量(包括形状为 () 的数组,但不包括形状为 (1,) 等的数组)。

  • argnums – 可选,整数或整数序列。指定要针对其进行微分的哪个位置参数(默认值为 0)。

  • has_aux – 可选,布尔值。指示 fun 是否返回一对,其中第一个元素被视为要进行微分的数学函数的输出,而第二个元素是辅助数据。默认值为 False。

  • holomorphic – 可选,布尔值。指示 fun 是否承诺为全纯函数。如果为 True,则输入和输出必须为复数。默认值为 False。

  • allow_int – 可选,布尔值。是否允许针对整数值输入进行微分。整数输入的梯度将具有一个简单的向量空间类型 (float0)。默认值为 False。

  • reduce_axes – 可选,轴名称元组。如果某个轴在此处列出,并且 fun 隐式地在该轴上广播一个值,则反向传递将对相应的梯度执行 psum。否则,梯度将针对命名轴上的每个示例。例如,如果 'batch' 是一个命名的批处理轴,则 grad(f, reduce_axes=('batch',)) 将创建一个计算总梯度的函数,而 grad(f) 将创建一个计算每个示例梯度的函数。

flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[source]#

可处理模块/图节点作为参数的 jax.jit 的提升版本。

参数
  • fun

    要进行 jit 处理的函数。 fun 应该是一个纯函数,因为副作用可能只执行一次。

    fun 的参数和返回值应为数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)。由 static_argnums 指定的位置参数可以是任何内容,只要它们是可散列的并且定义了相等操作。静态参数作为编译缓存键的一部分包含在内,因此必须定义散列和相等操作符。

    JAX 保持对 fun 的弱引用,以用作编译缓存键,因此对象 fun 必须是弱可引用的。大多数 Callable 对象将已经满足此要求。

  • in_shardings

    fun 参数结构匹配的 Pytree,所有实际参数都被资源分配规范替换。也可以指定 Pytree 前缀(例如,整个子树中的一个值),在这种情况下,叶节点将广播到该子树中的所有值。

    in_shardings 参数是可选的。JAX 将从输入 jax.Array 推断分片,如果无法推断,则默认情况下将复制输入。

    有效的资源分配规范是
    • Sharding,它将决定如何对值进行分区。

      将被分区。使用此方法,无需使用网格上下文管理器。

    • None,将赋予 JAX 选择任何它想要的分片的自由。对于 in_shardings,JAX 将将其标记为已复制,但此行为将来可能会改变。对于 out_shardings,我们将依靠 XLA GSPMD 分区器来确定输出分片。

    每个维度的尺寸必须是分配给它的资源总数的倍数。这类似于 pjit 的 in_shardings。

  • out_shardings

    in_shardings 相似,但指定函数输出的资源分配。这类似于 pjit 的 out_shardings。

    out_shardings 参数是可选的。如果未指定,则 jax.jit() 将使用 GSPMD 的分片传播来确定输出的分片应是什么。

  • static_argnums

    一个可选的整数或整数集合,用于指定哪些位置参数应被视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。

    静态参数应该是可哈希的,这意味着 __hash____eq__ 都已实现,并且是不可变的。 使用这些常量的不同值调用 jitted 函数将触发重新编译。 必须将不是数组或其容器的参数标记为静态。

    如果既没有提供 static_argnums 也没有提供 static_argnames,则没有参数被视为静态。 如果未提供 static_argnums 但提供了 static_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找任何与 static_argnames 对应的定位参数(反之亦然)。 如果同时提供 static_argnumsstatic_argnames,则不会使用 inspect.signature,并且只有在 static_argnumsstatic_argnames 中列出的实际参数将被视为静态。

  • static_argnames – 一个可选的字符串或字符串集合,指定要将哪些命名参数视为静态(编译时常量)。 有关详细信息,请参阅 static_argnums 的注释。 如果未提供,但设置了 static_argnums,则默认值基于调用 inspect.signature(fun) 来查找对应的命名参数。

  • donate_argnums

    指定哪些定位参数缓冲区“捐赠”给计算。 如果计算完成后不再需要参数缓冲区,则捐赠它们是安全的。 在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如将您的一个输入缓冲区循环利用来存储结果。 您不应该重用捐赠给计算的缓冲区,如果您尝试重用,JAX 会引发错误。 默认情况下,不会捐赠任何参数缓冲区。

    如果既没有提供 donate_argnums 也没有提供 donate_argnames,则没有参数被捐赠。 如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找任何与 donate_argnames 对应的定位参数(反之亦然)。 如果同时提供 donate_argnumsdonate_argnames,则不会使用 inspect.signature,并且只有在 donate_argnumsdonate_argnames 中列出的实际参数将被捐赠。

    有关缓冲区捐赠的更多详细信息,请参阅 FAQ

  • donate_argnames – 一个可选的字符串或字符串集合,指定哪些命名参数被捐赠给计算。 有关详细信息,请参阅 donate_argnums 的注释。 如果未提供,但设置了 donate_argnums,则默认值基于调用 inspect.signature(fun) 来查找对应的命名参数。

  • keep_unused – 如果为 False(默认值),JAX 确定为 fun 未使用的参数 *可能* 会从生成的已编译 XLA 可执行文件中删除。 此类参数不会被传输到设备,也不会被提供给底层可执行文件。 如果为 True,则不会修剪未使用的参数。

  • device – 这是一个实验性功能,API 可能会发生变化。 可选的,jitted 函数将运行的设备。(可以通过 jax.devices() 获取可用设备。) 默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend – 这是一个实验性功能,API 可能会发生变化。 可选的,一个字符串,表示 XLA 后端:'cpu''gpu''tpu'

  • inline – 指定此函数是否应该内联到封闭的 jaxprs 中(而不是表示为具有其自身子 jaxpr 的 xla_call 原语的应用)。 默认值为 False。

返回值

一个包装后的 fun 版本,设置为进行即时编译。

flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[source]#
flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[source]#
flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[source]#

jax.vmap 的引用感知版本。

参数
  • f – 要在其他轴上映射的函数。

  • in_axes – 一个整数、None 或值序列,指定要映射哪些输入数组轴(见 jax.vmap)。 除了整数和 None 之外,还可以使用 StateAxes 来控制图形节点(如模块)如何通过指定要应用于给定 Filter 的图形节点子状态的轴来进行向量化。

  • out_axes – 一个整数、None 或 pytree,指示映射的轴应出现在输出中的位置(见 jax.vmap)。

  • axis_name – 可选的,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合。

  • axis_size – 可选的,一个整数,指示要映射的轴的大小。 如果未提供,则映射的轴大小将从参数中推断。

返回值

具有与 f 对应的参数的 f 的批处理/向量化版本,但在由 in_axes 指示的位置具有额外的数组轴,以及与 f 对应的返回值,但在由 out_axes 指示的位置具有额外的数组轴。

示例

>>> from flax import nnx
>>> from jax import random, numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
...
>>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
... def forward(model, x):
...   return model(x)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)
>>> class LinearEnsemble(nnx.Module):
...   def __init__(self, num, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
...
>>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2,))
...
>>> @nnx.vmap(in_axes=(0, None), out_axes=0)
... def forward(model, x):
...   return jnp.dot(x, model.w.value)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)

为了控制图形节点子状态如何进行向量化,可以将 StateAxes 传递给 in_axesout_axes,指定要应用于给定过滤器的每个子状态的轴。 以下示例显示了如何在保持不同批处理统计信息和 dropout 随机状态的同时,在集成成员之间共享参数

>>> class Foo(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.arange(4))
...     self.b = nnx.BatchStat(jnp.arange(4))
...
>>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
... def mul(foo):
...   return foo.a * foo.b
...
>>> foo = Foo()
>>> y = mul(foo)
>>> y
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
flax.nnx.eval_shape(f, *args, **kwargs)[source]#
flax.nnx.cond(pred, true_fun, false_fun, *operands, **kwargs)[source]#