从 Haiku 迁移到 Flax#
本指南将展示 Haiku、Flax Linen 和 Flax NNX 之间的区别。Haiku 和 Linen 都强制执行无状态模块的功能范式,而 NNX 是一个新的下一代 API,它拥抱 Python 语言以提供更直观的开发体验。
基本示例#
要创建自定义模块,您需要在 Haiku 和 Flax 中从 Module
基类继承。模块可以在 Haiku 和 Flax Linen 中内联定义(使用 @nn.compact
装饰器),而模块不能在 NNX 中内联定义,必须在 __init__
中定义。
Linen 需要一个 deterministic
参数来控制是否使用 dropout。NNX 也使用一个 deterministic
参数,但该值可以使用稍后将展示的 .eval()
和 .train()
方法设置。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
from flax.experimental import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
由于模块在 Haiku 和 Linen 中内联定义,因此参数是通过推断样本输入的形状延迟初始化的。在 Flax NNX 中,模块是有状态的,并且是急切初始化的。这意味着必须在模块实例化期间显式传递输入形状,因为 NNX 中没有形状推断。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
...
model = Model(256, 10)
...
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
要获取 Haiku 和 Linen 中的模型参数,您可以使用 init
方法,该方法使用一个 random.key
加上一些输入来运行模型。
在 NNX 中,模型参数在用户实例化模型时自动初始化,因为输入形状已在实例化时显式传递。
由于 NNX 是急切的,并且模块在实例化时绑定,因此用户可以通过点访问来访问参数(以及在 __init__
中定义的其他字段)。另一方面,Haiku 和 Linen 使用延迟初始化,因此只有在使用样本输入初始化模块后才能访问参数,并且这两个框架都不支持其属性的点访问。
sample_x = jnp.ones((1, 784))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
sample_x = jnp.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables["params"]
assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
...
# parameters were already initialized during model instantiation
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
让我们看一下参数结构。在 Haiku 和 Linen 中,我们可以简单地检查从 .init()
返回的 params
对象。
要查看 NNX 中的参数结构,用户可以调用 nnx.split
来生成 Graphdef
和 State
对象。 Graphdef
是一个表示模型结构的静态 pytree(有关示例用法,请参阅 NNX 基础)。 State
对象包含所有模块变量(即任何子类为 nnx.Variable
的类)。如果我们过滤 nnx.Param
,我们将生成一个包含所有可学习模块参数的 State
对象。
...
{
'model/block/linear': {
'b': (256,),
'w': (784, 256),
},
'model/linear': {
'b': (10,),
'w': (256, 10),
}
}
...
...
FrozenDict({
Block_0: {
Dense_0: {
bias: (256,),
kernel: (784, 256),
},
},
Dense_0: {
bias: (10,),
kernel: (256, 10),
},
})
graphdef, params, rngs = nnx.split(model, nnx.Param, nnx.RngState)
params
State({
'block': {
'linear': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
},
'linear': {
'bias': VariableState(type=Param, value=(10,)),
'kernel': VariableState(type=Param, value=(256, 10))
}
})
在 Haiku 和 Linen 中训练期间,您将参数结构传递给 apply
方法以运行正向传递。要使用 dropout,我们必须传递 training=True
并提供一个 key
给 apply
以生成随机 dropout 掩码。要使用 NNX 中的 dropout,我们首先调用 model.train()
,这将设置 dropout 层的 deterministic
属性为 False
(相反,调用 model.eval()
将设置 deterministic
为 True
)。由于有状态的 NNX 模块已经包含参数和 RNG 密钥(用于 dropout),因此我们只需要调用模块即可运行正向传递。我们使用 nnx.split
提取可学习的参数(所有可学习的参数都是 NNX 类 nnx.Param
的子类),然后使用 nnx.update
应用梯度并有状态更新模型。
要编译 train_step
,我们使用 @jax.jit
(针对 Haiku 和 Linen)和 @nnx.jit
(针对 NNX)来装饰函数。类似于 @jax.jit
,@nnx.jit
也会编译函数,并具有允许用户编译以 NNX 模块为参数的函数的附加功能。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params, key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # set deterministic=False
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(
inputs, # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
# we can use Ellipsis to filter out the rest of the variables
_, params, _ = nnx.split(model, nnx.Param, ...)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, params)
Flax 还提供了一个方便的 TrainState
数据类,用于捆绑模型、参数和优化器,以简化训练和更新模型。在 Haiku 和 Linen 中,我们只需将 model.apply
函数、初始化参数和优化器作为参数传递给 TrainState
构造函数。
在 NNX 中,我们必须首先在模型上调用 nnx.split
以获取分离的 GraphDef
和 State
对象。我们可以传递 nnx.Param
将所有可训练参数过滤到单个 State
中,并传递 ...
用于其余变量。我们还需要子类化 TrainState
以添加其他变量的字段。然后,我们可以将 GraphDef.apply
作为应用函数、State
作为参数和其他变量以及优化器作为参数传递给 TrainState
构造函数。需要注意的是,GraphDef.apply
将接受 State
作为参数,并返回一个可调用函数。该函数可以在输入上调用以输出模型的 logits,以及更新的 GraphDef
和 State
对象。对于我们当前使用 dropout 的示例,这并不需要,但在下一节中,您将看到使用这些更新的对象与批量归一化等层相关。请注意,我们还使用 @jax.jit
,因为我们没有将 NNX 模块传递给 train_step
。
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(key, state, inputs, labels):
def loss_fn(params):
logits = state.apply_fn(
params, key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(key, state, inputs, labels):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state
from flax.training import train_state
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
class TrainState(train_state.TrainState):
other_variables: nnx.State
state = TrainState.create(
apply_fn=graphdef.apply,
params=params,
other_variables=other_variables,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, other_variables):
logits, (graphdef, new_state) = state.apply_fn(
params,
other_variables
)(inputs) # <== inputs
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params, state.other_variables)
state = state.apply_gradients(grads=grads)
return state
处理状态#
现在让我们看看所有三个框架如何处理可变状态。我们将采用与之前相同的模型,但现在我们将用批量归一化替换 Dropout。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.BatchNorm(
momentum=0.99
)(x, use_running_average=not training)
x = jax.nn.relu(x)
return x
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
x = jax.nn.relu(x)
return x
Haiku 需要一个 is_training
参数,而 Linen 需要一个 use_running_average
参数来控制是否更新运行统计信息。NNX 也使用一个 use_running_average
参数,但该值可以使用稍后将展示的 .eval()
和 .train()
方法设置。
与之前一样,您需要传递输入形状以在 NNX 中急切地构造模块。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
...
model = Model(256, 10)
...
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
要初始化 Haiku 和 Linen 中的参数和状态,您只需像以前一样调用 init
方法。但是,在 Haiku 中,您现在将获得 batch_stats
作为第二个返回值,而在 Linen 中,您将获得 variables
字典中一个新的 batch_stats
集合。请注意,由于 hk.BatchNorm
仅在 is_training=True
时初始化批量统计信息,因此我们必须在使用 hk.BatchNorm
层初始化 Haiku 模型的参数时设置 training=True
。在 Linen 中,我们可以像往常一样设置 training=False
。
在 NNX 中,参数和状态在模块实例化时已经初始化。批量统计信息属于 nnx.BatchStat
类,该类是 nnx.Variable
类的子类(不是 nnx.Param
,因为它们不是可学习的参数)。调用不带任何其他过滤参数的 nnx.split
将默认返回包含所有 nnx.Variable
的状态。
sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(
random.key(0),
sample_x, training=True # <== inputs
)
...
sample_x = jnp.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
...
graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)
现在,Haiku 和 Linen 中的训练看起来非常相似,因为您使用相同的 apply
方法来运行正向传递。在 Haiku 中,现在将 batch_stats
作为第二个参数传递给 apply
,并将新更新的 batch_stats
作为第二个返回值。在 Linen 中,您改为将 batch_stats
添加为输入字典中的一个新键,并将 updates
变量字典作为第二个返回值。要更新批量统计信息,我们必须将 training=True
传递给 apply
。
在 NNX 中,训练代码与之前的示例相同,因为批次统计信息(与有状态的 NNX 模块绑定)是有状态更新的。为了在 NNX 中更新批次统计信息,我们首先调用 model.train()
,这将把 batchnorm 层的 use_running_average
属性设置为 False
(反之,调用 model.eval()
将把 use_running_average
设置为 True
)。由于有状态的 NNX 模块已经包含参数和批次统计信息,我们只需调用该模块来运行前向传播。我们使用 nnx.split
来提取可学习参数(所有可学习参数都是 NNX 类 nnx.Param
的子类),然后使用 nnx.update
应用梯度并有状态地更新模型。
...
@jax.jit
def train_step(params, batch_stats, inputs, labels):
def loss_fn(params, batch_stats):
logits, batch_stats = model.apply(
params, batch_stats,
None, # <== rng
inputs, training=True # <== inputs
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, batch_stats
grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, batch_stats
...
@jax.jit
def train_step(params, batch_stats, inputs, labels):
def loss_fn(params, batch_stats):
logits, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
inputs, training=True, # <== inputs
mutable='batch_stats',
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, updates["batch_stats"]
grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, batch_stats
model.train() # set use_running_average=False
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(
inputs, # <== inputs
) # batch statistics are updated statefully in this step
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss
grads = nnx.grad(loss_fn)(model)
_, params, _ = nnx.split(model, nnx.Param, ...)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, params)
为了使用 TrainState
,我们子类化以添加一个可以存储批次统计信息的额外字段。
...
class TrainState(train_state.TrainState):
batch_stats: Any
state = TrainState.create(
apply_fn=model.apply,
params=params,
batch_stats=batch_stats,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, batch_stats):
logits, batch_stats = state.apply_fn(
params, batch_stats,
None, # <== rng
inputs, training=True # <== inputs
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, batch_stats
grads, batch_stats = jax.grad(
loss_fn, has_aux=True
)(state.params, state.batch_stats)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=batch_stats)
return state
...
class TrainState(train_state.TrainState):
batch_stats: Any
state = TrainState.create(
apply_fn=model.apply,
params=params,
batch_stats=batch_stats,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, batch_stats):
logits, updates = state.apply_fn(
{'params': params, 'batch_stats': batch_stats},
inputs, training=True, # <== inputs
mutable='batch_stats'
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, updates['batch_stats']
grads, batch_stats = jax.grad(
loss_fn, has_aux=True
)(state.params, state.batch_stats)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=batch_stats)
return state
model.train() # set deterministic=False
graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)
class TrainState(train_state.TrainState):
batch_stats: Any
state = TrainState.create(
apply_fn=graphdef.apply,
params=params,
batch_stats=batch_stats,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, batch_stats):
logits, (graphdef, new_state) = state.apply_fn(
params, batch_stats
)(inputs) # <== inputs
_, batch_stats = new_state.split(nnx.Param, nnx.BatchStat)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, batch_stats
grads, batch_stats = jax.grad(
loss_fn, has_aux=True
)(state.params, state.batch_stats)
state = state.apply_gradients(grads=grads)
state = state.replace(batch_stats=batch_stats)
return state
使用多种方法#
在本节中,我们将看看如何在所有三个框架中使用多种方法。作为示例,我们将实现一个具有三种方法的自动编码器模型:encode
、decode
和 __call__
。
与之前一样,我们定义编码器和解码器层,无需传入输入形状,因为模块参数将在 Haiku 和 Linen 中使用形状推断进行延迟初始化。在 NNX 中,我们必须传入输入形状,因为模块参数将在没有形状推断的情况下被急切地初始化。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
与之前一样,我们在实例化 NNX 模块时传入输入形状。
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
...
model = AutoEncoder(256, 784)
...
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
对于 Haiku 和 Linen,init
可用于触发 __call__
方法以初始化模型的参数,该模型同时使用 encode
和 decode
方法。这将创建模型所需的所有必要参数。在 NNX 中,参数在模块实例化时已初始化。
params = model.init(
random.key(0),
x=jnp.ones((1, 784)),
)
params = model.init(
random.key(0),
x=jnp.ones((1, 784)),
)['params']
# parameters were already initialized during model instantiation
...
参数结构如下:
...
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
...
FrozenDict({
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
})
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
},
'encoder': {
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
最后,让我们探索如何执行前向传播。在 Haiku 和 Linen 中,我们使用 apply
函数来调用 encode
方法。在 NNX 中,我们可以直接调用 encode
方法。
encode, decode = model.apply
z = encode(
params,
None, # <== rng
x=jnp.ones((1, 784)),
)
...
z = model.apply(
{"params": params},
x=jnp.ones((1, 784)),
method="encode",
)
...
z = model.encode(jnp.ones((1, 784)))
...
提升转换#
Flax 和 Haiku 都提供了一组转换,我们将其称为提升转换,它们以这样一种方式包装 JAX 转换:它们可以与模块一起使用,有时还可以提供额外的功能。在本节中,我们将看看如何在 Flax 和 Haiku 中使用 scan
的提升版本来实现一个简单的 RNN 层。
首先,我们将定义一个 RNNCell
模块,它将包含 RNN 单步逻辑。我们还将定义一个 initial_state
方法,该方法将用于初始化 RNN 的状态(也称为 carry
)。与 jax.lax.scan
一样,RNNCell.__call__
方法将是一个接收 carry 和输入并返回新 carry 和输出的函数。在这种情况下,carry 和输出相同。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个 RNN
模块,它将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell
,然后用它来构造 carry
,最后使用 hk.scan
在输入序列上运行 RNNCell
。
在 Linen 中,我们将使用 nn.scan
定义一个新的临时类型,它将包装 RNNCell
。在此过程中,我们还将指定指示 nn.scan
广播 params
集合(所有步骤共享相同的参数)并且不拆分 params
rng 流(以便所有步骤都使用相同的参数初始化),最后我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出沿第二个轴堆叠起来。然后我们将立即使用这个临时类型来创建一个提升的 RNNCell
实例,并用它来创建 carry
,并运行 __call__
方法,它将在序列上进行 scan
。
在 NNX 中,我们定义了一个 scan 函数 scan_fn
,它将使用在 __init__
中定义的 RNNCell
来扫描序列。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(
cell, carry,
jnp.swapaxes(x, 1, 0)
)
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(
RNNCell, variable_broadcast='params',
split_rngs={'params': False}, in_axes=1, out_axes=1
)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
通常,Flax 和 Haiku 之间的提升转换的主要区别在于,在 Haiku 中,提升转换不操作状态,也就是说,Haiku 将处理 params
和 state
,使其在转换内部和外部保持相同的形状。在 Flax 中,提升转换可以同时操作变量集合和 rng 流,用户必须根据转换的语义定义每个转换如何处理不同的集合。
与之前示例一样,参数必须通过 .init()
初始化并传递到 .apply()
以在 Haiku 和 Linen 中进行前向传播。在 NNX 中,参数已被急切地初始化并绑定到有状态模块,并且该模块可以简单地调用输入以进行前向传播。
x = jnp.ones((3, 12, 32))
def forward(x):
return RNN(64)(x)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(
random.key(0),
x=jnp.ones((3, 12, 32)),
)
y = model.apply(
params,
x=jnp.ones((3, 12, 32)),
)
x = jnp.ones((3, 12, 32))
model = RNN(64)
params = model.init(
random.key(0),
x=jnp.ones((3, 12, 32)),
)['params']
y = model.apply(
{'params': params},
x=jnp.ones((3, 12, 32)),
)
x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))
y = model(x)
...
与前几节中的示例相比,唯一值得注意的变化是,这次我们在 Haiku 中使用了 hk.without_apply_rng
,因此我们不必将 rng
参数作为 None
传递到 apply
方法。
扫描层#
scan
的一个非常重要的应用是,将一系列层迭代地应用于输入,并将每层的输出作为下一层的输入。这对于减少大型模型的编译时间非常有用。作为示例,我们将创建一个简单的 Block
模块,然后在将 Block
模块应用 num_layers
次的 MLP
模块中使用它。
在 Haiku 中,我们按照通常的方式定义 Block
模块,然后在 MLP
内部,我们将使用 hk.experimental.layer_stack
在一个 stack_block
函数上创建一个 Block
模块堆栈。
在 Linen 中,Block
的定义略有不同,__call__
将接受并返回一个第二个虚拟输入/输出,在这两种情况下都将是 None
。在 MLP
中,我们将使用 nn.scan
,如前一个示例中一样,但通过设置 split_rngs={'params': True}
和 variable_axes={'params': 0}
,我们告诉 nn.scan
为每个步骤创建不同的参数,并在第一个轴上切片 params
集合,有效地实现了一个 Block
模块堆栈,如 Haiku 中所示。
在 NNX 中,我们使用 nnx.Scan.constructor()
定义一个 Block
模块堆栈。然后我们可以简单地调用 Block
的堆栈 self.blocks
,在输入和 carry 上,以获得前向传播输出。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array, _):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x, None
class MLP(nnx.Module):
def __init__(self, input_dim, features, num_layers, rngs):
self.blocks = nnx.Scan.constructor(
Block, length=num_layers
)(input_dim, features, rngs=rngs)
def __call__(self, x):
y, _ = self.blocks(x, None)
return y
请注意,在 Flax 中,我们将 None
作为第二个参数传递给 ScanBlock
并忽略其第二个输出。这些表示每步的输入/输出,但它们是 None
,因为在这种情况下我们没有任何输入/输出。
初始化每个模型与之前的示例相同。在这种情况下,我们将指定要使用 5
个层,每个层都有 64
个特征。与之前一样,我们还为 NNX 传入输入形状。
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 64))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
model = MLP(64, num_layers=5)
sample_x = jnp.ones((1, 64))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)['params']
...
model = MLP(64, 64, num_layers=5, rngs=nnx.Rngs(0))
...
使用层扫描时,您应该注意的一件事是,所有层都被融合成单个层,其参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...)
开头,因为我们使用了 5
个层。
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
...
FrozenDict({
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
})
...
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
'blocks': {
'scan_module': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
}
})
顶层 Haiku 函数与顶层 Flax 模块#
在 Haiku 中,可以使用原始的 hk.{get,set}_{parameter,state}
来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”编写为函数非常常见。
Flax 团队推荐使用更以模块为中心的方案,该方案使用 __call__
来定义前向函数。在 Linen 中,相应的访问器将是 Module.param
和 Module.variable
(有关集合的解释,请参见处理状态)。在 NNX 中,可以使用常规的 Python 类语义来设置和访问参数和变量。
...
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter(
'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(random.key(0), jnp.ones((1, 64)))
...
class FooModule(nn.Module):
@nn.compact
def __call__(self, x):
counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
multiplier = self.param(
'multiplier', nn.initializers.ones_init(), [1,], x.dtype
)
output = x + multiplier * counter.value
if not self.is_initializing(): # otherwise model.init() also increases it
counter.value += 1
return output
model = FooModule()
variables = model.init(random.key(0), jnp.ones((1, 64)))
params, counter = variables['params'], variables['counter']
class Counter(nnx.Variable):
pass
class FooModule(nnx.Module):
def __init__(self, rngs):
self.counter = Counter(jnp.ones((), jnp.int32))
self.multiplier = nnx.Param(
nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
)
def __call__(self, x):
output = x + self.multiplier * self.counter.value
self.counter.value += 1
return output
model = FooModule(rngs=nnx.Rngs(0))
_, params, counter = nnx.split(model, nnx.Param, Counter)