variables

class flax.nnx.BatchStat(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]

BatchNorm 层中存储的均值和方差批统计量。请注意，它们不是可学习的缩放（scale）和偏置（bias）参数，而是训练结束后在推理阶段通常使用的滑动平均统计量。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(3,)
  ),
  'scale': VariableState(
    type=Param,
    value=(3,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(3,)
  )
})
class flax.nnx.Cache(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]

MultiHeadAttention 中用于自回归解码的缓存。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
  'cache_index': VariableState(
    type=Cache,
    value=()
  ),
  'cached_key': VariableState(
    type=Cache,
    value=(1, 2, 3)
  ),
  'cached_value': VariableState(
    type=Cache,
    value=(1, 2, 3)
  )
})
class flax.nnx.Intermediate(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]

通常与 Module.sow() 配合使用的 Variable 类型。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> model = Model(rngs=nnx.Rngs(0))

>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
State({
  'i': VariableState(
    type=Intermediate,
    value=((1, 3),)
  )
})
class flax.nnx.Param(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]

规范的可学习参数。NNX 层模块中的所有可学习参数都具有 Param Variable 类型。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
class flax.nnx.Variable(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]

所有 Variable 类型的基类。可通过继承此类来创建自定义的 Variable 类型。许多 NNX 图函数可以按特定的 Variable 类型进行筛选，例如 split()、state()、pop() 和 State.filter()。

示例用法：

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))

>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})

>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  )
})

>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  ),
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[source]
class flax.nnx.VariableState(type, value, **metadata)[source]
flax.nnx.with_metadata(initializer, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]