转换#

一般的 JAX 转换在 Pytrees 数组上运行并遵循值语义,这对 Flax NNX 来说是一个挑战,Flax NNX 将模块表示为遵循引用语义的常规 Python 对象。为了解决这个问题,Flax NNX 引入了自己的转换集,这些转换扩展了 JAX 转换,允许模块和其他 Flax NNX 对象在转换中传入和传出,同时保留引用语义。

Flax NNX 转换对那些以前使用过 JAX 转换的人来说应该很熟悉,因为它们使用相同的 API 并且在仅处理数组的 Pytrees 时表现得像 JAX 转换。但是,在处理 Flax NNX 对象时,它们允许 Python 的引用语义保留在这些对象中,这包括

  • 保留转换输入和输出中多个对象之间的共享引用。

  • 将对转换内部对象进行的任何状态更改传播到转换外部的对象。

  • 在多个输入和输出中存在别名时,强制执行对象转换方式的一致性。

import jax
from jax import numpy as jnp, random
from flax import nnx

在本指南中,我们将使用 nnx.vmap 作为案例研究,以演示 Flax NNX 转换的工作原理,但此处概述的原则扩展到所有转换。

基本示例#

首先,让我们看看一个使用 nnx.vmap 来扩展按元素的 vector_dot 函数以在批处理输入上工作的简单示例。我们将定义一个不带方法的 Weights 模块来保存一些参数,这些权重将作为输入传递给 vector_dot 函数以及一些数据。权重和数据都将在轴 0 上进行批处理,我们将使用 nnx.vmapvector_dot 应用于每个批处理元素,结果将在轴 1 上进行批处理

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
)
x = jax.random.normal(random.key(1), (10, 2))

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

请注意,in_axes 自然地与 Weights 模块交互,将其视为数组的 Pytree。也允许使用前缀模式,in_axes=(0, 0) 在这种情况下也能工作。

对象也允许作为 Flax NNX 转换的输出,这对于转换初始化器非常有用。例如,我们可以定义一个 create_weights 函数来创建一个单一的 Weights 模块,并使用 nnx.vmap 来创建一个与以前形状相同的 Weights 堆栈

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights)(seeds)
nnx.display(weights)

转换方法#

Python 中的方法只是将实例作为第一个参数的函数,这意味着您可以从 Module 和其他 Flax NNX 子类型装饰方法。例如,我们可以重构前面的示例中的 Weights,并用 vmap 装饰 __init__ 来完成 create_weights 的工作,并添加一个 __call__ 方法并用 vmap 装饰它来完成 vector_dot 的工作

class WeightStack(nnx.Module):
  @nnx.vmap
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=0, out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)

print(f'{y.shape = }')
nnx.display(weights)
y.shape = (3, 10)

在本指南的其余部分,我们将重点介绍转换单个函数,但是,请注意所有示例都可以很容易地用这种方法风格编写。

状态传播#

到目前为止,我们的函数都是无状态的。但是,Flax NNX 转换的真正强大之处在于我们有状态函数,因为它们的主要特征之一是传播状态更改以保留引用语义。让我们通过向 Weights 添加一个 count 属性并在新的 stateful_vector_dot 函数中递增它来更新我们的示例。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias


y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)

weights.count
Count(
  value=Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32)
)

在运行 stateful_vector_dot 一次后,我们验证 count 属性是否已正确更新。由于 Weights 是向量化的,count 被初始化为一个 arange(10),并且它的所有元素在转换中都递增了 1。最重要的是,更新传播到了转换外部的原始 Weights 对象。很好!

图形更新传播#

JAX 转换将输入视为数组的 Pytrees,而 Flax NNX 将输入视为数组的 Pytrees 和 Python 引用,引用形成一个图。Flax NNX 的状态传播机制可以跟踪对象中的任意更新,只要它们是输入的本地更新(不支持转换内部全局的更新)。这意味着您可以根据需要修改图形结构,包括更新现有属性、添加/删除属性、交换属性、在对象之间共享(新)引用、在对象之间共享变量等。一切皆有可能!

以下示例演示了在 nnx.vmap 内部对 Weights 对象进行一些任意更新,并验证更新是否已正确传播到转换外部的原始 Weights 对象。

class Count(nnx.Variable): pass

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.arange(10),
)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  weights.some_property = ['a', 2, False] # add attribute
  del weights.bias # delete attribute
  weights.new_param = weights.kernel # share reference
  return y

y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)

nnx.display(weights)

能力越大,责任越大。
- 叔叔本

虽然这个功能非常强大,但必须谨慎使用,因为它可能会与某些转换中 JAX 的底层假设冲突。例如,jit 期望输入的结构稳定,以便缓存已编译的函数,因此在 nnx.jit 编制的函数内部更改图形结构会导致持续的重新编译和性能下降。另一方面,scan 仅允许固定 carry 结构,因此添加/删除声明为 carry 的子状态会导致错误。

转换子状态(提升类型)#

某些 JAX 转换允许使用 Pytree 前缀来指定如何转换输入/输出的不同部分。Flax NNX 支持 Pytree 前缀以用于 Pytree 结构,但目前它没有针对图形对象的缀的概念。相反,Flax NNX 引入了 Lift Types 的概念,它允许指定如何转换对象的不同子状态。不同的转换支持不同的提升类型,以下是每个转换当前支持的提升类型列表

提升类型

转换

StateAxes

vmap, pmap, scan

StateSharding

jit, shard_map

DiffState

grad, value_and_grad, custom_vjp

注意:shard_map 尚未实现。

如果我们想要指定如何在 nnx.vmap 中向量化对象的不同的子状态,我们创建一个 StateAxes,它通过 过滤器 将一组子状态映射到它们的相应轴,并将 StateAxes 传递给 in_axesout_axes,就好像它是一个 Pytree 前缀一样。让我们使用前面的 stateful_vector_dot 示例,并且只向量化 Param 变量,并广播 count 变量,以便我们只为所有批处理元素保留一个计数。为此,我们将定义一个 StateAxes,其中包含一个与 Param 变量匹配并将其映射到轴 0 的过滤器,以及所有 Count 变量映射到 None 的过滤器,并将此 StateAxes 传递给 Weights 对象的 in_axes

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)

weights = Weights(
  kernel=random.uniform(random.key(0), (10, 2, 3)),
  bias=jnp.zeros((10, 3)),
  count=jnp.array(0),
)
x = jax.random.normal(random.key(1), (10, 2))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)

weights.count
Count(
  value=Array(1, dtype=int32, weak_type=True)
)

这里 count 现在是一个标量,因为它没有被向量化。另外,请注意 StateAxes 只能直接用于 Flax NNX 对象,不能用作对象 Pytree 的前缀。

随机状态#

在 Flax NNX 中,随机状态只是普通状态。这意味着它存储在需要它的模块内部,并且被视为任何其他类型的状态。与 Flax Linen 中使用单独机制处理随机状态相比,这是一个简化。在实践中,模块只需要保留对 Rngs 对象的引用,该对象在初始化期间传递给它们,并使用它为每个随机操作生成唯一的密钥。对于本指南的目的,这意味着随机状态可以像任何其他类型的状态一样进行转换,但我们也需要了解状态的布局方式,以便能够正确地转换它。

假设我们想改变一下,将相同的权重应用于批次中的所有元素,但我们希望为每个元素添加不同的随机噪声。为此,我们将向 Weights 添加一个 Rngs 属性,该属性从构造期间传入的 seed 密钥参数创建,此种子密钥必须事先 split,以便我们能够成功地对其进行矢量化。出于教学目的,我们将种子密钥分配给 noise 流并从中进行采样。为了对 RNG 状态进行矢量化,我们必须将 StateAxes 配置为将所有 RngStateRngs 中所有变量的基类)映射到轴 0,并将 ParamCount 映射到 None

class Weights(nnx.Module):
  def __init__(self, kernel, bias, count, seed):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
    self.count = Count(count)
    self.rngs = nnx.Rngs(noise=seed)

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=random.split(random.key(0), num=10),
)
x = random.normal(random.key(1), (10, 2))

def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})
y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)
y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

由于 Rngs 的状态是在原地更新的,并由 nnx.vmap 自动传播,因此每次调用 noisy_vector_dot 时,我们将获得不同的结果。

在上面的示例中,我们在构造期间手动拆分了随机状态,这很好,因为它使意图清晰,但它也不允许我们在 vmap 之外使用 Rngs,因为它的状态始终处于拆分状态。为了解决这个问题,我们传递一个未拆分的种子,并在 vmap 之前使用 nnx.split_rngs 装饰器,在每次调用函数之前拆分 RngState,然后将其“降低”回来,以便它可以使用。

weights = Weights(
  kernel=random.uniform(random.key(0), (2, 3)),
  bias=jnp.zeros((3,)),
  count=jnp.array(0),
  seed=0,
)
x = random.normal(random.key(1), (10, 2))

state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})

@nnx.split_rngs(splits=10)
@nnx.vmap(in_axes=(state_axes, 0))
def noisy_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  y = x @ weights.kernel + weights.bias
  return y + random.normal(weights.rngs.noise(), y.shape)

y1 = noisy_vector_dot(weights, x)
y2 = noisy_vector_dot(weights, x)

print(jnp.allclose(y1, y2))
nnx.display(weights)
False

一致的别名#

允许在转换中使用引用语义的主要问题是引用可以在输入和输出之间共享,如果处理不当,这会导致定义不明确或行为不一致。在下面的示例中,我们有一个名为 m 的单个 Weights 模块,它的引用在 arg1arg2 中的多个地方出现。问题是我们还指定要沿轴 0arg1 进行矢量化,沿轴 1arg2 进行矢量化,这在 JAX 中是正常的,因为 pytree 的引用透明性,但在 Flax NNX 中存在问题,因为我们试图以两种不同的方式对 m 进行矢量化。NNX 将通过抛出错误来强制一致性。

class Weights(nnx.Module):
  def __init__(self, array: jax.Array):
    self.param = nnx.Param(array)

m = Weights(jnp.arange(10))
arg1 = {'a': {'b': m}, 'c': m}
arg2 = [(m, m), m]

@nnx.vmap(in_axes=(0, 1))
def f(arg1, arg2):
  ...

try:
  f(arg1, arg2)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variables.Param'>
  param: 0
  param: 0
  param: 1
Node: <class '__main__.Weights'>
  <root>: 0
  <root>: 0
  <root>: 1

不一致的别名也可能发生在输入和输出之间。在下一个示例中,我们有一个简单的函数,它接受并立即返回 arg1,但是 arg1 在输入时沿轴 0 进行矢量化,在输出时沿轴 1 进行矢量化。正如预期的那样,这是有问题的,Flax NNX 将抛出错误。

@nnx.vmap(in_axes=0, out_axes=1)
def f(arg1):
  return arg1

try:
  f(arg1)
except ValueError as e:
  print(e)
Inconsistent aliasing detected. The following nodes have different prefixes:
Node: <class 'flax.nnx.variables.Param'>
  param: 0
  param: 0
  param: 1
Node: <class '__main__.Weights'>
  <root>: 0
  <root>: 0
  <root>: 1

轴元数据#

Flax NNX 变量可以保存任意元数据,可以通过将它们作为关键字参数传递给它们的构造函数来添加。这通常用于存储 sharding 信息,这些信息由 nnx.spmd API(如 nnx.get_partition_specnnx.get_named_sharding)使用。但是,在涉及转换时,通常需要将此轴相关信息与轴的实际状态保持同步,例如,如果我们沿轴 1 对变量进行矢量化,我们应该在 vmapscan 内部删除位置 1 处的 sharding 信息,以反映轴被暂时移除的事实。为了实现这一点,Flax NNX 转换提供了一个非标准的 transform_metadata 字典参数,当存在 nnx.PARTITION_NAME 密钥时,sharding 元数据将根据 in_axesout_axes 指定进行更新。让我们看一个实际的例子

class Weights(nnx.Module):
  def __init__(self, array: jax.Array, sharding: tuple[str | None, ...]):
    self.param = nnx.Param(array, sharding=sharding)

m = Weights(jnp.ones((3, 4, 5)), sharding=('a', 'b', None))

@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
def f(m: Weights):
  print(f'Inner {m.param.shape = }')
  print(f'Inner {m.param.sharding = }')

f(m)
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Inner m.param.shape = (3, 5)
Inner m.param.sharding = ('a', None)
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)

在这里,我们向 Param 变量添加了一个 sharding 元数据,并使用 transform_metadata 更新 sharding 元数据以反映轴的变化,具体来说,我们可以看到第一个轴 bvmap 内部被从 sharding 元数据中删除,然后在 vmap 外部添加回来。

我们可以验证,当在转换内部创建模块时,这也适用,新的 sharding 轴将被添加到模块的变量之外,与转换的变量的轴匹配。

@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
def init_vmap():
  return Weights(jnp.ones((3, 5)), sharding=('a', None))

m = init_vmap()
print(f'Outter {m.param.shape = }')
print(f'Outter {m.param.sharding = }')
Outter m.param.shape = (3, 4, 5)
Outter m.param.sharding = ('a', 'b', None)