Optimizer#
- class flax.nnx.optimizer.Optimizer(*args, **kwargs)#
用于具有单个 Optax 优化器的常见情况的简单训练状态。
示例用法
>>> import jax, jax.numpy as jnp >>> from flax import nnx >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... return self.linear2(self.linear1(x)) ... >>> x = jax.random.normal(jax.random.key(0), (1, 2)) >>> y = jnp.ones((1, 4)) ... >>> model = Model(nnx.Rngs(0)) >>> tx = optax.adam(1e-3) >>> state = nnx.Optimizer(model, tx) ... >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() >>> loss_fn(model) Array(1.7055722, dtype=float32) >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads) >>> loss_fn(model) Array(1.6925814, dtype=float32)
请注意,您可以通过子类化此类来轻松扩展它,以存储其他数据(例如,添加指标)。
示例用法
>>> class TrainState(nnx.Optimizer): ... def __init__(self, model, tx, metrics): ... self.metrics = metrics ... super().__init__(model, tx) ... def update(self, *, grads, **updates): ... self.metrics.update(**updates) ... super().update(grads) ... >>> metrics = nnx.metrics.Average() >>> state = TrainState(model, tx, metrics) ... >>> grads = nnx.grad(loss_fn)(state.model) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() Array(1.6925814, dtype=float32) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() Array(1.68612, dtype=float32)
对于更奇特的用例(例如,多个优化器),最好复制此类并对其进行修改。
- step#
一个
OptState
Variable
,用于跟踪步数。
- model#
包装的
Module
。
- tx#
Optax 梯度变换。
- opt_state#
Optax 优化器状态。
- __init__(model, tx, wrt=<class 'flax.nnx.variables.Param'>)#
实例化此类并包装
Module
和 Optax 梯度变换。实例化优化器状态以跟踪Variable
在wrt
中指定的数据类型。将步数设置为 0。- 参数
model – 一个 NNX 模块。
tx – 一个 Optax 梯度变换。
wrt – 可选参数,用于过滤在优化器状态中跟踪哪些
Variable
。这些应该是您计划更新的Variable
;也就是说,此参数值应该与传递给nnx.grad
调用的wrt
参数匹配,该调用将生成将传递到update()
方法的grads
参数的梯度。
- update(grads)#
更新返回值中的
step
、params
、opt_state
和**kwargs
。必须从nnx.grad(..., wrt=self.wrt)
中派生grads
,其中梯度相对于在实例化此Optimizer
期间在self.wrt
中定义的相同Variable
类型。例如>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> import optax >>> class CustomVariable(nnx.Variable): ... pass >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.custom_variable = CustomVariable(jnp.ones((1, 3))) ... def __call__(self, x): ... return self.linear(x) + self.custom_variable >>> model = Model(rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(model)) State({ 'custom_variable': VariableState( type=CustomVariable, value=(1, 3) ), 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> # update: >>> # - only Linear layer parameters >>> # - only CustomVariable parameters >>> # - both Linear layer and CustomVariable parameters >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() >>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)): ... # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad` ... state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) ... grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))( ... state.model, jnp.ones((1, 2)), jnp.ones((1, 3)) ... ) ... state.update(grads=grads)
请注意,在内部,此函数调用
.tx.update()
,然后调用optax.apply_updates()
来更新params
和opt_state
。- 参数
grads – 从
nnx.grad
派生的梯度。