标准化#

class flax.nnx.BatchNorm(*args, **kwargs)[source]#

BatchNorm 模块。

要计算输入上的批次规范并更新批次统计信息,请调用 train() 方法(或在构造函数或调用时传入 use_running_average=False)。

要使用存储的批次统计信息的运行平均值,请调用 eval() 方法(或在构造函数或调用时传入 use_running_average=True)。

示例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(6,)
  ),
  'mean': VariableState(
    type=BatchStat,
    value=(6,)
  ),
  'scale': VariableState(
    type=Param,
    value=(6,)
  ),
  'var': VariableState(
    type=BatchStat,
    value=(6,)
  )
})

>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.state(layer, nnx.BatchStat)
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'].value != batch_stats2['mean'].value).all()
>>> assert (batch_stats1['var'].value != batch_stats2['var'].value).all()

>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
num_features#

输入特征的数量。

use_running_average#

如果为 True,则使用存储的批次统计信息,而不是计算输入上的批次统计信息。

axis#

输入的特征或非批次轴。

momentum#

批次统计信息指数移动平均值的衰减率。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)。

use_bias#

如果为 True,则添加偏差(beta)。

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nn.relu)时,可以禁用此选项,因为缩放将由下一层完成。

bias_init#

偏差的初始化器,默认值为零。

scale_init#

比例的初始化器,默认值为一。

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参见 jax.pmap(默认值:None)。

axis_index_groups#

该命名轴内的轴索引组,表示要减少的设备子集(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批次归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为真,则使用更快的但数值稳定性较差的方差计算。

rngs#

rng 密钥。

__call__(x, use_running_average=None, *, mask=None)[source]#

使用批次统计信息对输入进行规范化。

参数
  • x – 要规范化的输入。

  • use_running_average – 如果为真,则使用存储的批次统计信息,而不是计算输入上的批次统计信息。传递给调用方法的 use_running_average 标志将优先于传递给构造函数的 use_running_average 标志。

返回值

规范化的输入(与输入形状相同)。

方法

class flax.nnx.LayerNorm(*args, **kwargs)[source]#

层规范化 (https://arxiv.org/abs/1607.06450).

LayerNorm 对批次中的每个给定示例独立地规范化层的激活,而不是像批次规范化那样跨批次进行规范化。即应用一种变换,该变换使每个示例内的平均激活接近 0,而激活标准差接近 1。

示例用法

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
num_features#

输入特征的数量。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)。

use_bias#

如果为 True,则添加偏差(beta)。

use_scale#

如果为 True,则乘以比例(gamma)。当下一层是线性的(例如 nnx.relu)时,可以禁用此选项,因为缩放将由下一层完成。

bias_init#

偏差的初始化器,默认值为零。

scale_init#

比例的初始化器,默认值为一。

reduction_axes#

用于计算规范化统计信息的轴。

feature_axes#

用于学习偏差和缩放的特征轴。

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参见 jax.pmap(默认值:None)。这仅在模型跨设备进行细分时需要,即被规范化的数组在 pmap 中跨设备进行分片。

axis_index_groups#

该命名轴内的轴索引组,表示要减少的设备子集(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批次归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为真,则使用更快的但数值稳定性较差的方差计算。

rngs#

rng 密钥。

__call__(x, *, mask=None)[source]#

对输入应用层归一化。

参数

x – 输入

返回值

规范化的输入(与输入形状相同)。

方法

class flax.nnx.RMSNorm(*args, **kwargs)[source]#

RMS 层归一化 (https://arxiv.org/abs/1910.07467).

RMSNorm 对每个批次中的给定样本独立地归一化层激活,而不是像批量归一化一样跨批次归一化。与将均值重新居中为 0 并按激活的标准差归一化的层归一化不同,RMSNorm 根本不重新居中,而是按激活的均方根归一化。

示例用法

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer)
State({
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})

>>> y = layer(x)
num_features#

输入特征的数量。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)。

use_scale#

如果为 True,则乘以比例因子(gamma)。当下一层为线性层(例如 nn.relu)时,可以禁用此选项,因为比例因子将由下一层完成。

scale_init#

比例的初始化器,默认值为一。

reduction_axes#

用于计算规范化统计信息的轴。

feature_axes#

用于学习偏差和缩放的特征轴。

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的说明,请参见 jax.pmap(默认值:None)。这仅在模型跨设备进行细分时需要,即被规范化的数组在 pmap 中跨设备进行分片。

axis_index_groups#

该命名轴内的轴索引组,表示要减少的设备子集(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批次归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为真,则使用更快的但数值稳定性较差的方差计算。

rngs#

rng 密钥。

__call__(x, mask=None)[source]#

对输入应用层归一化。

参数

x – 输入

返回值

规范化的输入(与输入形状相同)。

方法

class flax.nnx.GroupNorm(*args, **kwargs)[source]#

组归一化 (arxiv.org/abs/1803.08494).

此操作类似于批量归一化,但统计信息在大小相同的通道组之间共享,而不是在批次维度之间共享。因此,组归一化不依赖于批次组成,也不需要维护用于存储统计信息的内部状态。用户应指定通道组的总数或每个组的通道数。

注意

层归一化是组归一化的一个特例,其中 num_groups=1

示例用法

>>> from flax import nnx
>>> import jax
>>> import numpy as np
...
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': VariableState(
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState(
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
...
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
num_features#

输入特征/通道数。

num_groups#

通道组的总数。原始组归一化论文建议默认值为 32。

group_size#

一个组中的通道数。

epsilon#

添加到方差中的一个小浮点数,以避免除以零。

dtype#

结果的 dtype(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)。

use_bias#

如果为 True,则添加偏差(beta)。

use_scale#

如果为 True,则乘以比例因子(gamma)。当下一层为线性层(例如 nn.relu)时,可以禁用此选项,因为比例因子将由下一层完成。

bias_init#

偏差的初始化器,默认值为零。

scale_init#

比例的初始化器,默认值为一。

reduction_axes#

用于计算归一化统计信息的轴列表。此列表必须包含最后一个维度,该维度被假定为特征轴。此外,如果在调用时使用的输入与用于初始化的数据相比具有额外的前导轴,例如由于批处理,则需要显式定义约简轴。

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参见 jax.pmap(默认值:无)。这仅在模型在设备之间细分时需要,即被归一化的数组在 pmap 或分片映射内的设备之间被分片。对于 SPMD jit,您不需要手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集体操作。

axis_index_groups#

该命名轴内的轴索引组,表示要减少的设备子集(默认值:None)。例如,[[0, 1], [2, 3]] 将独立地对前两个和最后两个设备上的示例进行批次归一化。有关更多详细信息,请参见 jax.lax.psum

use_fast_variance#

如果为真,则使用更快的但数值稳定性较差的方差计算。

rngs#

rng 密钥。

__call__(x, *, mask=None)[source]#

对输入应用组归一化 (arxiv.org/abs/1803.08494).

参数
  • x – 形状为 ...self.num_features 的输入,其中 self.num_features 是一个通道维度,而 ... 表示任意数量的额外维度,这些维度可用于累积统计信息。如果没有指定约简轴,则除了表示批次的第一个维度之外,所有额外维度 ... 将用于累积统计信息。

  • mask – 形状可广播到 inputs 张量的二进制数组,指示应计算均值和方差的位置。

返回值

规范化的输入(与输入形状相同)。

方法