variables#
- class flax.nnx.BatchStat(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
存储在 BatchNorm 层中的均值和方差批次统计量。请注意，这些并不是可学习的缩放（scale）和偏置（bias）参数，而是通常在训练结束后的推理阶段使用的滑动平均统计量。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(3,)
  ),
  'scale': VariableState(
    type=Param,
    value=(3,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(3,)
  )
})
- class flax.nnx.Cache(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
MultiHeadAttention 中使用的自回归缓存。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
  'cache_index': VariableState(
    type=Cache,
    value=()
  ),
  'cached_key': VariableState(
    type=Cache,
    value=(1, 2, 3)
  ),
  'cached_value': VariableState(
    type=Cache,
    value=(1, 2, 3)
  )
})
- class flax.nnx.Intermediate(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
通常用于 Module.sow() 的 Variable 类型。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> model = Model(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
State({
  'i': VariableState(
    type=Intermediate,
    value=((1, 3),)
  )
})
- class flax.nnx.Param(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
规范的可学习参数。NNX 层模块中的所有可学习参数都属于 Param 这一 Variable 类型。

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
- class flax.nnx.Variable(value, *, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
所有 Variable 类型的基类。可以通过继承此类来创建自定义的 Variable 类型。许多 NNX 图函数都可以按特定的 Variable 类型进行筛选，例如 split()、state()、pop() 和 State.filter()。

示例用法：

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> class CustomVariable(nnx.Variable):
...   pass
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))
>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  )
})
>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  ),
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})
- class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[source]#