state#
- class flax.nnx.State(mapping, /, *, _copy=True)[source]#
一个类似于 pytree 的结构,包含一个从字符串或整数到叶子的
Mapping
。有效的叶子类型是Variable
、jax.Array
、numpy.ndarray
或嵌套的State
。可以通过调用split()
或state()
在Module
上生成State
。- filter(first, /, *filters)[source]#
将一个
State
过滤为一个或多个State
。用户必须至少传递一个Filter
(即Variable
)。此方法类似于split()
,但过滤器可以是非详尽的。示例用法
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param = state.filter(nnx.Param) >>> batch_stats = state.filter(nnx.BatchStat) >>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
- 参数
first – 第一个过滤器
*filters – 可选的,用于将状态分组为互斥子状态的额外过滤器。
- 返回值
一个或多个
States
,等于传递的过滤器数量。
- static merge(state, /, *states)[source]#
split()
的逆操作。merge
接受一个或多个State
,并创建一个新的State
。示例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> params.linear.bias.value += 1 >>> state = nnx.State.merge(params, batch_stats) >>> nnx.update(model, state) >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
- 参数
state – 一个
State
对象。*states – 额外的
State
对象。
- 返回值
合并后的
State
。
- split(first, /, *filters)[source]#
将一个
State
拆分为一个或多个State
。用户必须至少传递一个Filter
(即Variable
),并且过滤器必须是详尽的(即它们必须涵盖State
中的所有Variable
类型)。示例用法
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
- 参数
first – 第一个过滤器
*filters – 可选的,用于将状态分组为互斥子状态的额外过滤器。
- 返回值
一个或多个
States
,等于传递的过滤器数量。