Flax#
Neural Networks for JAX
Flax 为使用 JAX 和神经网络的研究人员提供了端到端且灵活的用户体验。 Flax 充分发挥了 JAX 的全部潜力。
其核心是Flax NNX,一个简化的 API,它使在 JAX 中创建、检查、调试和分析神经网络变得更容易。 它对 Python 引用语义提供了头等支持,允许用户使用常规 Python 对象表达他们的模型。 Flax NNX 是之前 Flax Linen API 的演变,经过多年的经验,它带来了更简单、更友好的用户体验。
注意
Flax Linen API 不会在短期内被弃用,因为大多数用户仍然依赖于此 API,但是鼓励新用户使用 Flax NNX。 对于想要从 Linen 迁移到 NNX 的现有 Linen 用户,请查看 演变指南。
特性#
基本用法#
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
@nnx.jit # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads) # in-place updates
return loss