Flax#

Neural Networks for JAX


Flax 为使用 JAX 和神经网络的研究人员提供了端到端且灵活的用户体验。 Flax 充分发挥了 JAX 的全部潜力。

其核心是Flax NNX,一个简化的 API,它使在 JAX 中创建、检查、调试和分析神经网络变得更容易。 它对 Python 引用语义提供了头等支持,允许用户使用常规 Python 对象表达他们的模型。 Flax NNX 是之前 Flax Linen API 的演变,经过多年的经验,它带来了更简单、更友好的用户体验。

注意

Flax Linen API 不会在短期内被弃用,因为大多数用户仍然依赖于此 API,但是鼓励新用户使用 Flax NNX。 对于想要从 Linen 迁移到 NNX 的现有 Linen 用户,请查看 演变指南

特性#

Pythonic

Flax NNX 支持使用常规 Python 对象,提供直观且可预测的开发体验。

简单

Flax NNX 依赖于 Python 的对象模型,这为用户带来了简便性并提高了开发速度。

表现力强

Flax NNX 通过其 过滤器 系统允许对模型状态进行细粒度控制。

熟悉

Flax NNX 通过 函数式 API 使得将对象与常规 JAX 代码集成变得非常容易。

基本用法#

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing

@nnx.jit  # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss

安装#

通过 pip 安装

pip install flax

或从仓库安装最新版本

pip install git+https://github.com/google/flax.git

了解更多#

Flax NNX 基础
nnx_basics.html
MNIST 教程
mnist_tutorial.html
Flax Linen 到 Flax NNX
guides/linen_to_nnx.html
术语表
glossary.html