从 Haiku 迁移到 Flax#

本指南将展示 Haiku、Flax Linen 和 Flax NNX 之间的区别。Haiku 和 Linen 都强制执行无状态模块的功能范式,而 NNX 是一个新的下一代 API,它拥抱 Python 语言以提供更直观的开发体验。

基本示例#

要创建自定义模块,您需要在 Haiku 和 Flax 中从 Module 基类继承。模块可以在 Haiku 和 Flax Linen 中内联定义(使用 @nn.compact 装饰器),而模块不能在 NNX 中内联定义,必须在 __init__ 中定义。

Linen 需要一个 deterministic 参数来控制是否使用 dropout。NNX 也使用一个 deterministic 参数,但该值可以使用稍后将展示的 .eval().train() 方法设置。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x
from flax.experimental import nnx

class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    self.block = Block(din, dmid, rngs=rngs)
    self.linear = nnx.Linear(dmid, dout, rngs=rngs)


  def __call__(self, x):
    x = self.block(x)
    x = self.linear(x)
    return x

由于模块在 Haiku 和 Linen 中内联定义,因此参数是通过推断样本输入的形状延迟初始化的。在 Flax NNX 中,模块是有状态的,并且是急切初始化的。这意味着必须在模块实例化期间显式传递输入形状,因为 NNX 中没有形状推断。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
...


model = Model(256, 10)
...


model = Model(784, 256, 10, rngs=nnx.Rngs(0))

要获取 Haiku 和 Linen 中的模型参数,您可以使用 init 方法,该方法使用一个 random.key 加上一些输入来运行模型。

在 NNX 中,模型参数在用户实例化模型时自动初始化,因为输入形状已在实例化时显式传递。

由于 NNX 是急切的,并且模块在实例化时绑定,因此用户可以通过点访问来访问参数(以及在 __init__ 中定义的其他字段)。另一方面,Haiku 和 Linen 使用延迟初始化,因此只有在使用样本输入初始化模块后才能访问参数,并且这两个框架都不支持其属性的点访问。

sample_x = jnp.ones((1, 784))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)


assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
sample_x = jnp.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables["params"]

assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
...




# parameters were already initialized during model instantiation

assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)

让我们看一下参数结构。在 Haiku 和 Linen 中,我们可以简单地检查从 .init() 返回的 params 对象。

要查看 NNX 中的参数结构,用户可以调用 nnx.split 来生成 GraphdefState 对象。 Graphdef 是一个表示模型结构的静态 pytree(有关示例用法,请参阅 NNX 基础)。 State 对象包含所有模块变量(即任何子类为 nnx.Variable 的类)。如果我们过滤 nnx.Param,我们将生成一个包含所有可学习模块参数的 State 对象。

...


{
  'model/block/linear': {
    'b': (256,),
    'w': (784, 256),
  },
  'model/linear': {
    'b': (10,),
    'w': (256, 10),
  }
}

...
...


FrozenDict({
  Block_0: {
    Dense_0: {
      bias: (256,),
      kernel: (784, 256),
    },
  },
  Dense_0: {
    bias: (10,),
    kernel: (256, 10),
  },
})
graphdef, params, rngs = nnx.split(model, nnx.Param, nnx.RngState)

params
State({
  'block': {
    'linear': {
      'bias': VariableState(type=Param, value=(256,)),
      'kernel': VariableState(type=Param, value=(784, 256))
    }
  },
  'linear': {
    'bias': VariableState(type=Param, value=(10,)),
    'kernel': VariableState(type=Param, value=(256, 10))
  }
})

在 Haiku 和 Linen 中训练期间,您将参数结构传递给 apply 方法以运行正向传递。要使用 dropout,我们必须传递 training=True 并提供一个 keyapply 以生成随机 dropout 掩码。要使用 NNX 中的 dropout,我们首先调用 model.train(),这将设置 dropout 层的 deterministic 属性为 False(相反,调用 model.eval() 将设置 deterministicTrue)。由于有状态的 NNX 模块已经包含参数和 RNG 密钥(用于 dropout),因此我们只需要调用模块即可运行正向传递。我们使用 nnx.split 提取可学习的参数(所有可学习的参数都是 NNX 类 nnx.Param 的子类),然后使用 nnx.update 应用梯度并有状态更新模型。

要编译 train_step,我们使用 @jax.jit(针对 Haiku 和 Linen)和 @nnx.jit(针对 NNX)来装饰函数。类似于 @jax.jit@nnx.jit 也会编译函数,并具有允许用户编译以 NNX 模块为参数的函数的附加功能。

...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      params, key,
      inputs, training=True # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)


  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
...

@jax.jit
def train_step(key, params, inputs, labels):
  def loss_fn(params):
    logits = model.apply(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)


  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
model.train() # set deterministic=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(

      inputs, # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = nnx.grad(loss_fn)(model)
  # we can use Ellipsis to filter out the rest of the variables
  _, params, _ = nnx.split(model, nnx.Param, ...)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  nnx.update(model, params)

Flax 还提供了一个方便的 TrainState 数据类,用于捆绑模型、参数和优化器,以简化训练和更新模型。在 Haiku 和 Linen 中,我们只需将 model.apply 函数、初始化参数和优化器作为参数传递给 TrainState 构造函数。

在 NNX 中,我们必须首先在模型上调用 nnx.split 以获取分离的 GraphDefState 对象。我们可以传递 nnx.Param 将所有可训练参数过滤到单个 State 中,并传递 ... 用于其余变量。我们还需要子类化 TrainState 以添加其他变量的字段。然后,我们可以将 GraphDef.apply 作为应用函数、State 作为参数和其他变量以及优化器作为参数传递给 TrainState 构造函数。需要注意的是,GraphDef.apply 将接受 State 作为参数,并返回一个可调用函数。该函数可以在输入上调用以输出模型的 logits,以及更新的 GraphDefState 对象。对于我们当前使用 dropout 的示例,这并不需要,但在下一节中,您将看到使用这些更新的对象与批量归一化等层相关。请注意,我们还使用 @jax.jit,因为我们没有将 NNX 模块传递给 train_step

from flax.training import train_state







state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,

  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      params, key,
      inputs, training=True # <== inputs

    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)


  state = state.apply_gradients(grads=grads)

  return state
from flax.training import train_state







state = train_state.TrainState.create(
  apply_fn=model.apply,
  params=params,

  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(key, state, inputs, labels):
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      inputs, training=True, # <== inputs
      rngs={'dropout': key}
    )
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params)


  state = state.apply_gradients(grads=grads)

  return state
from flax.training import train_state

model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

class TrainState(train_state.TrainState):
  other_variables: nnx.State

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  other_variables=other_variables,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, other_variables):
    logits, (graphdef, new_state) = state.apply_fn(
      params,
      other_variables

    )(inputs) # <== inputs
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(state.params, state.other_variables)


  state = state.apply_gradients(grads=grads)

  return state

处理状态#

现在让我们看看所有三个框架如何处理可变状态。我们将采用与之前相同的模型,但现在我们将用批量归一化替换 Dropout。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features



  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x
class Block(nn.Module):
  features: int




  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.BatchNorm(
      momentum=0.99
    )(x, use_running_average=not training)
    x = jax.nn.relu(x)
    return x
class Block(nnx.Module):
  def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(
      num_features=out_features, momentum=0.99, rngs=rngs
    )

  def __call__(self, x):
    x = self.linear(x)
    x = self.batchnorm(x)


    x = jax.nn.relu(x)
    return x

Haiku 需要一个 is_training 参数,而 Linen 需要一个 use_running_average 参数来控制是否更新运行统计信息。NNX 也使用一个 use_running_average 参数,但该值可以使用稍后将展示的 .eval().train() 方法设置。

与之前一样,您需要传递输入形状以在 NNX 中急切地构造模块。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform_with_state(forward)
...


model = Model(256, 10)
...


model = Model(784, 256, 10, rngs=nnx.Rngs(0))

要初始化 Haiku 和 Linen 中的参数和状态,您只需像以前一样调用 init 方法。但是,在 Haiku 中,您现在将获得 batch_stats 作为第二个返回值,而在 Linen 中,您将获得 variables 字典中一个新的 batch_stats 集合。请注意,由于 hk.BatchNorm 仅在 is_training=True 时初始化批量统计信息,因此我们必须在使用 hk.BatchNorm 层初始化 Haiku 模型的参数时设置 training=True。在 Linen 中,我们可以像往常一样设置 training=False

在 NNX 中,参数和状态在模块实例化时已经初始化。批量统计信息属于 nnx.BatchStat 类,该类是 nnx.Variable 类的子类(不是 nnx.Param,因为它们不是可学习的参数)。调用不带任何其他过滤参数的 nnx.split 将默认返回包含所有 nnx.Variable 的状态。

sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(
  random.key(0),
  sample_x, training=True # <== inputs
)
...
sample_x = jnp.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
...




graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)

现在,Haiku 和 Linen 中的训练看起来非常相似,因为您使用相同的 apply 方法来运行正向传递。在 Haiku 中,现在将 batch_stats 作为第二个参数传递给 apply,并将新更新的 batch_stats 作为第二个返回值。在 Linen 中,您改为将 batch_stats 添加为输入字典中的一个新键,并将 updates 变量字典作为第二个返回值。要更新批量统计信息,我们必须将 training=True 传递给 apply

在 NNX 中,训练代码与之前的示例相同,因为批次统计信息(与有状态的 NNX 模块绑定)是有状态更新的。为了在 NNX 中更新批次统计信息,我们首先调用 model.train(),这将把 batchnorm 层的 use_running_average 属性设置为 False(反之,调用 model.eval() 将把 use_running_average 设置为 True)。由于有状态的 NNX 模块已经包含参数和批次统计信息,我们只需调用该模块来运行前向传播。我们使用 nnx.split 来提取可学习参数(所有可学习参数都是 NNX 类 nnx.Param 的子类),然后使用 nnx.update 应用梯度并有状态地更新模型。

...

@jax.jit
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, batch_stats = model.apply(
      params, batch_stats,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, batch_stats

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats)

  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats
...

@jax.jit
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, updates = model.apply(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable='batch_stats',
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates["batch_stats"]

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats)

  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats
model.train() # set use_running_average=False

@nnx.jit
def train_step(model, inputs, labels):
  def loss_fn(model):
    logits = model(

      inputs, # <== inputs

    ) # batch statistics are updated statefully in this step
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss

  grads = nnx.grad(loss_fn)(model)
  _, params, _ = nnx.split(model, nnx.Param, ...)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  nnx.update(model, params)

为了使用 TrainState,我们子类化以添加一个可以存储批次统计信息的额外字段。

...


class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, batch_stats = state.apply_fn(
      params, batch_stats,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, batch_stats

  grads, batch_stats = jax.grad(
    loss_fn, has_aux=True
  )(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats)

  return state
...


class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable='batch_stats'
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates['batch_stats']

  grads, batch_stats = jax.grad(
    loss_fn, has_aux=True
  )(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats)

  return state
model.train() # set deterministic=False
graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)

class TrainState(train_state.TrainState):
  batch_stats: Any

state = TrainState.create(
  apply_fn=graphdef.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, inputs, labels):
  def loss_fn(params, batch_stats):
    logits, (graphdef, new_state) = state.apply_fn(
      params, batch_stats
    )(inputs) # <== inputs

    _, batch_stats = new_state.split(nnx.Param, nnx.BatchStat)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, batch_stats

  grads, batch_stats = jax.grad(
    loss_fn, has_aux=True
  )(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=batch_stats)

  return state

使用多种方法#

在本节中,我们将看看如何在所有三个框架中使用多种方法。作为示例,我们将实现一个具有三种方法的自动编码器模型:encodedecode__call__

与之前一样,我们定义编码器和解码器层,无需传入输入形状,因为模块参数将在 Haiku 和 Linen 中使用形状推断进行延迟初始化。在 NNX 中,我们必须传入输入形状,因为模块参数将在没有形状推断的情况下被急切地初始化。

class AutoEncoder(hk.Module):


  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
class AutoEncoder(nnx.Module):



  def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
    self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
    self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

与之前一样,我们在实例化 NNX 模块时传入输入形状。

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
...




model = AutoEncoder(256, 784)
...




model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))

对于 Haiku 和 Linen,init 可用于触发 __call__ 方法以初始化模型的参数,该模型同时使用 encodedecode 方法。这将创建模型所需的所有必要参数。在 NNX 中,参数在模块实例化时已初始化。

params = model.init(
  random.key(0),
  x=jnp.ones((1, 784)),
)
params = model.init(
  random.key(0),
  x=jnp.ones((1, 784)),
)['params']
# parameters were already initialized during model instantiation


...

参数结构如下:

...


{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
...


FrozenDict({
    decoder: {
        bias: (784,),
        kernel: (256, 784),
    },
    encoder: {
        bias: (256,),
        kernel: (784, 256),
    },
})
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'decoder': {
    'bias': VariableState(type=Param, value=(784,)),
    'kernel': VariableState(type=Param, value=(256, 784))
  },
  'encoder': {
    'bias': VariableState(type=Param, value=(256,)),
    'kernel': VariableState(type=Param, value=(784, 256))
  }
})

最后,让我们探索如何执行前向传播。在 Haiku 和 Linen 中,我们使用 apply 函数来调用 encode 方法。在 NNX 中,我们可以直接调用 encode 方法。

encode, decode = model.apply
z = encode(
  params,
  None, # <== rng
  x=jnp.ones((1, 784)),

)
...
z = model.apply(
  {"params": params},

  x=jnp.ones((1, 784)),
  method="encode",
)
...
z = model.encode(jnp.ones((1, 784)))




...

提升转换#

Flax 和 Haiku 都提供了一组转换,我们将其称为提升转换,它们以这样一种方式包装 JAX 转换:它们可以与模块一起使用,有时还可以提供额外的功能。在本节中,我们将看看如何在 Flax 和 Haiku 中使用 scan 的提升版本来实现一个简单的 RNN 层。

首先,我们将定义一个 RNNCell 模块,它将包含 RNN 单步逻辑。我们还将定义一个 initial_state 方法,该方法将用于初始化 RNN 的状态(也称为 carry)。与 jax.lax.scan 一样,RNNCell.__call__ 方法将是一个接收 carry 和输入并返回新 carry 和输出的函数。在这种情况下,carry 和输出相同。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
  def __init__(self, input_size, hidden_size, rngs):
    self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = self.linear(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,我们将定义一个 RNN 模块,它将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell,然后用它来构造 carry,最后使用 hk.scan 在输入序列上运行 RNNCell

在 Linen 中,我们将使用 nn.scan 定义一个新的临时类型,它将包装 RNNCell。在此过程中,我们还将指定指示 nn.scan 广播 params 集合(所有步骤共享相同的参数)并且不拆分 params rng 流(以便所有步骤都使用相同的参数初始化),最后我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出沿第二个轴堆叠起来。然后我们将立即使用这个临时类型来创建一个提升的 RNNCell 实例,并用它来创建 carry,并运行 __call__ 方法,它将在序列上进行 scan

在 NNX 中,我们定义了一个 scan 函数 scan_fn,它将使用在 __init__ 中定义的 RNNCell 来扫描序列。

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(
      cell, carry,
      jnp.swapaxes(x, 1, 0)
    )
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(
      RNNCell, variable_broadcast='params',
      split_rngs={'params': False}, in_axes=1, out_axes=1
    )(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)

    return y
class RNN(nnx.Module):
  def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
    self.hidden_size = hidden_size
    self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)

  def __call__(self, x):
    scan_fn = lambda carry, cell, x: cell(carry, x)
    carry = self.cell.initial_state(x.shape[0])
    carry, y = nnx.scan(
      scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
    )(carry, self.cell, x)

    return y

通常,Flax 和 Haiku 之间的提升转换的主要区别在于,在 Haiku 中,提升转换不操作状态,也就是说,Haiku 将处理 paramsstate,使其在转换内部和外部保持相同的形状。在 Flax 中,提升转换可以同时操作变量集合和 rng 流,用户必须根据转换的语义定义每个转换如何处理不同的集合。

与之前示例一样,参数必须通过 .init() 初始化并传递到 .apply() 以在 Haiku 和 Linen 中进行前向传播。在 NNX 中,参数已被急切地初始化并绑定到有状态模块,并且该模块可以简单地调用输入以进行前向传播。

x = jnp.ones((3, 12, 32))

def forward(x):
  return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
  random.key(0),
  x=jnp.ones((3, 12, 32)),
)

y = model.apply(
  params,
  x=jnp.ones((3, 12, 32)),
)
x = jnp.ones((3, 12, 32))




model = RNN(64)

params = model.init(
  random.key(0),
  x=jnp.ones((3, 12, 32)),
)['params']

y = model.apply(
  {'params': params},
  x=jnp.ones((3, 12, 32)),
)
x = jnp.ones((3, 12, 32))




model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))






y = model(x)


...

与前几节中的示例相比,唯一值得注意的变化是,这次我们在 Haiku 中使用了 hk.without_apply_rng,因此我们不必将 rng 参数作为 None 传递到 apply 方法。

扫描层#

scan 的一个非常重要的应用是,将一系列层迭代地应用于输入,并将每层的输出作为下一层的输入。这对于减少大型模型的编译时间非常有用。作为示例,我们将创建一个简单的 Block 模块,然后在将 Block 模块应用 num_layers 次的 MLP 模块中使用它。

在 Haiku 中,我们按照通常的方式定义 Block 模块,然后在 MLP 内部,我们将使用 hk.experimental.layer_stack 在一个 stack_block 函数上创建一个 Block 模块堆栈。

在 Linen 中,Block 的定义略有不同,__call__ 将接受并返回一个第二个虚拟输入/输出,在这两种情况下都将是 None。在 MLP 中,我们将使用 nn.scan,如前一个示例中一样,但通过设置 split_rngs={'params': True}variable_axes={'params': 0},我们告诉 nn.scan 为每个步骤创建不同的参数,并在第一个轴上切片 params 集合,有效地实现了一个 Block 模块堆栈,如 Haiku 中所示。

在 NNX 中,我们使用 nnx.Scan.constructor() 定义一个 Block 模块堆栈。然后我们可以简单地调用 Block 的堆栈 self.blocks,在输入和 carry 上,以获得前向传播输出。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers



  def __call__(self, x, training: bool):
    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int




  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y
class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array, _):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x, None

class MLP(nnx.Module):
  def __init__(self, input_dim, features, num_layers, rngs):
    self.blocks = nnx.Scan.constructor(
      Block, length=num_layers
    )(input_dim, features, rngs=rngs)



  def __call__(self, x):




    y, _ = self.blocks(x, None)
    return y

请注意,在 Flax 中,我们将 None 作为第二个参数传递给 ScanBlock 并忽略其第二个输出。这些表示每步的输入/输出,但它们是 None,因为在这种情况下我们没有任何输入/输出。

初始化每个模型与之前的示例相同。在这种情况下,我们将指定要使用 5 个层,每个层都有 64 个特征。与之前一样,我们还为 NNX 传入输入形状。

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)

model = hk.transform(forward)

sample_x = jnp.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...


model = MLP(64, num_layers=5)

sample_x = jnp.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)['params']
...


model = MLP(64, 64, num_layers=5, rngs=nnx.Rngs(0))





...

使用层扫描时,您应该注意的一件事是,所有层都被融合成单个层,其参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...) 开头,因为我们使用了 5 个层。

...


{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}



...
...


FrozenDict({
    ScanBlock_0: {
        Dense_0: {
            bias: (5, 64),
            kernel: (5, 64, 64),
        },
    },
})

...
_, params, _ = nnx.split(model, nnx.Param, ...)

params
State({
  'blocks': {
    'scan_module': {
      'linear': {
        'bias': VariableState(type=Param, value=(5, 64)),
        'kernel': VariableState(type=Param, value=(5, 64, 64))
      }
    }
  }
})

顶层 Haiku 函数与顶层 Flax 模块#

在 Haiku 中,可以使用原始的 hk.{get,set}_{parameter,state} 来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”编写为函数非常常见。

Flax 团队推荐使用更以模块为中心的方案,该方案使用 __call__ 来定义前向函数。在 Linen 中,相应的访问器将是 Module.paramModule.variable(有关集合的解释,请参见处理状态)。在 NNX 中,可以使用常规的 Python 类语义来设置和访问参数和变量。

...


def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter(
    'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
  )

  output = x + multiplier * counter

  hk.set_state("counter", counter + 1)
  return output

model = hk.transform_with_state(forward)

params, state = model.init(random.key(0), jnp.ones((1, 64)))
...


class FooModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
    multiplier = self.param(
      'multiplier', nn.initializers.ones_init(), [1,], x.dtype
    )

    output = x + multiplier * counter.value
    if not self.is_initializing():  # otherwise model.init() also increases it
      counter.value += 1
    return output

model = FooModule()
variables = model.init(random.key(0), jnp.ones((1, 64)))
params, counter = variables['params'], variables['counter']
class Counter(nnx.Variable):
  pass

class FooModule(nnx.Module):

  def __init__(self, rngs):
    self.counter = Counter(jnp.ones((), jnp.int32))
    self.multiplier = nnx.Param(
      nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
    )
  def __call__(self, x):
    output = x + self.multiplier * self.counter.value

    self.counter.value += 1
    return output

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)