state#

class flax.nnx.State(mapping, /, *, _copy=True)[source]#

一个类似于 pytree 的结构,包含一个从字符串或整数到叶子的 Mapping。有效的叶子类型是 Variablejax.Arraynumpy.ndarray 或嵌套的 State。可以通过调用 split()state()Module 上生成 State

filter(first, /, *filters)[source]#

将一个 State 过滤为一个或多个 State。用户必须至少传递一个 Filter(即 Variable)。此方法类似于 split(),但过滤器可以是非详尽的。

示例用法

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param = state.filter(nnx.Param)
>>> batch_stats = state.filter(nnx.BatchStat)
>>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
参数
  • first – 第一个过滤器

  • *filters – 可选的,用于将状态分组为互斥子状态的额外过滤器。

返回值

一个或多个 States,等于传递的过滤器数量。

static merge(state, /, *states)[source]#

split() 的逆操作。

merge 接受一个或多个 State,并创建一个新的 State

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> params.linear.bias.value += 1

>>> state = nnx.State.merge(params, batch_stats)
>>> nnx.update(model, state)
>>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
参数
  • state – 一个 State 对象。

  • *states – 额外的 State 对象。

返回值

合并后的 State

split(first, /, *filters)[source]#

将一个 State 拆分为一个或多个 State。用户必须至少传递一个 Filter(即 Variable),并且过滤器必须是详尽的(即它们必须涵盖 State 中的所有 Variable 类型)。

示例用法

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
参数
  • first – 第一个过滤器

  • *filters – 可选的,用于将状态分组为互斥子状态的额外过滤器。

返回值

一个或多个 States,等于传递的过滤器数量。