Optimizer#

class flax.nnx.optimizer.Optimizer(*args, **kwargs)#

用于具有单个 Optax 优化器的常见情况的简单训练状态。

示例用法

>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> state = nnx.Optimizer(model, tx)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(1.7055722, dtype=float32)
>>> grads = nnx.grad(loss_fn)(state.model)
>>> state.update(grads)
>>> loss_fn(model)
Array(1.6925814, dtype=float32)

请注意,您可以通过子类化此类来轻松扩展它,以存储其他数据(例如,添加指标)。

示例用法

>>> class TrainState(nnx.Optimizer):
...   def __init__(self, model, tx, metrics):
...     self.metrics = metrics
...     super().__init__(model, tx)
...   def update(self, *, grads, **updates):
...     self.metrics.update(**updates)
...     super().update(grads)
...
>>> metrics = nnx.metrics.Average()
>>> state = TrainState(model, tx, metrics)
...
>>> grads = nnx.grad(loss_fn)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.6925814, dtype=float32)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.68612, dtype=float32)

对于更奇特的用例(例如,多个优化器),最好复制此类并对其进行修改。

step#

一个 OptState Variable,用于跟踪步数。

model#

包装的 Module

tx#

Optax 梯度变换。

opt_state#

Optax 优化器状态。

__init__(model, tx, wrt=<class 'flax.nnx.variables.Param'>)#

实例化此类并包装 Module 和 Optax 梯度变换。实例化优化器状态以跟踪 Variablewrt 中指定的数据类型。将步数设置为 0。

参数
  • model – 一个 NNX 模块。

  • tx – 一个 Optax 梯度变换。

  • wrt – 可选参数,用于过滤在优化器状态中跟踪哪些 Variable。这些应该是您计划更新的 Variable;也就是说,此参数值应该与传递给 nnx.grad 调用的 wrt 参数匹配,该调用将生成将传递到 update() 方法的 grads 参数的梯度。

update(grads)#

更新返回值中的 stepparamsopt_state**kwargs。必须从 nnx.grad(..., wrt=self.wrt) 中派生 grads,其中梯度相对于在实例化此 Optimizer 期间在 self.wrt 中定义的相同 Variable 类型。例如

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(model))
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  ),
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})

>>> # update:
>>> # - only Linear layer parameters
>>> # - only CustomVariable parameters
>>> # - both Linear layer and CustomVariable parameters
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
...   # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
...   state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
...   grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
...     state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
...   )
...   state.update(grads=grads)

请注意,在内部,此函数调用 .tx.update(),然后调用 optax.apply_updates() 来更新 paramsopt_state

参数

grads – 从 nnx.grad 派生的梯度。