线性

Linear#

NNX 线性层类。

class flax.nnx.Conv(*args, **kwargs)[source]#

卷积模块,包装了 lax.conv_general_dilated

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

输入特征数量的整数或元组。

out_features#

输出特征数量的整数或元组。

kernel_size#

卷积核的形状。对于一维卷积,核大小可以作为整数传递,这将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。

strides#

一个整数或 n 个整数的序列,代表窗口之间的步长(默认值:1)。

padding#

字符串 'SAME'、字符串 'VALID'、字符串 'CIRCULAR'(周期性边界条件)或 n(low, high) 整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个 int 被解释为在所有维度上应用相同的填充,在序列中传递单个 int 会导致两侧使用相同的填充。'CAUSAL' 填充用于一维卷积将左填充卷积轴,从而导致相同大小的输出。

input_dilation#

一个整数或 n 个整数的序列,指定在 inputs 的每个空间维度中应用的膨胀因子(默认值:1)。使用输入膨胀因子 d 的卷积等效于步长为 d 的转置卷积。

kernel_dilation#

一个整数或 n 个整数的序列,指定在卷积核的每个空间维度中应用的膨胀因子(默认值:1)。使用核膨胀因子的卷积也被称为“空洞卷积”。

feature_group_count#

整数,默认值 1。如果指定,则将输入特征分成组。

use_bias#

是否在输出中添加偏差(默认值:True)。

mask#

在掩码卷积期间对权重进行掩码的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

kernel_init#

卷积核的初始化器。

bias_init#

偏差的初始化器。

rngs#

rng 密钥。

__call__(inputs)[source]#

将(可能未共享)卷积应用于输入。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。这是 channels-last 约定,例如,对于二维卷积,为 NHWC,对于三维卷积,为 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度都将被展平成一个维度进行卷积,并在返回之前恢复。在某些情况下,直接对层进行 vmap 可能比这种默认展平方法具有更好的性能。如果输入缺少批次维度,则会为卷积添加批次维度并在返回时删除,这种做法是为了能够编写单个示例代码。

返回值

卷积后的数据。

方法

class flax.nnx.ConvTranspose(*args, **kwargs)[source]#

卷积模块,包装了 lax.conv_transpose

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.value.shape
(6, 6, 4, 3)
>>> layer.bias.value.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

输入特征数量的整数或元组。

out_features#

输出特征数量的整数或元组。

kernel_size#

卷积核的形状。对于一维卷积,核大小可以作为整数传递,这将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。

strides#

一个整数或 n 个整数的序列,代表窗口之间的步长(默认值:1)。

padding#

字符串 'SAME'、字符串 'VALID'、字符串 'CIRCULAR'(周期性边界条件)或 n(low, high) 整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个 int 被解释为在所有维度上应用相同的填充,在序列中传递单个 int 会导致两侧使用相同的填充。'CAUSAL' 填充用于一维卷积将左填充卷积轴,从而导致相同大小的输出。

kernel_dilation#

一个整数或 n 个整数的序列,指定在卷积核的每个空间维度中应用的膨胀因子(默认值:1)。使用核膨胀因子的卷积也被称为“空洞卷积”。

use_bias#

是否在输出中添加偏差(默认值:True)。

mask#

在掩码卷积期间对权重进行掩码的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

kernel_init#

卷积核的初始化器。

bias_init#

偏差的初始化器。

transpose_kernel#

如果 True 则翻转空间轴并交换核的输入/输出通道轴。

rngs#

rng 密钥。

__call__(inputs)[source]#

将转置卷积应用于输入。

行为与 jax.lax.conv_transpose 相似。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by ``lax.conv_general_dilated, which puts the spatial dimensions last. Note: 如果输入具有超过一个批处理维度,则所有批处理维度将被展平为单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接将层 vmap'ing 可能比这种默认展平方法产生更好的性能。如果输入缺少批处理维度,则将为卷积添加它,并在返回时删除它,允许编写单个示例代码。

返回值

卷积后的数据。

方法

class flax.nnx.Embed(*args, **kwargs)[source]#

嵌入模块。

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': VariableState(
    type=Param,
    value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
           [ 0.01070483,  0.27923733,  1.7487359 ],
           [ 0.59161806,  0.8660184 ,  1.2838588 ],
           [-0.748139  , -0.15856352,  0.06061118],
           [-0.4769059 , -0.6607095 ,  0.46697947]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
        [ 0.01070483,  0.27923733,  1.7487359 ],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]],

       [[-0.4769059 , -0.6607095 ,  0.46697947],
        [-0.748139  , -0.15856352,  0.06061118],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]]], dtype=float32)

一个从整数 [0, num_embeddings) 到 features 维向量参数化函数。此 Module 将创建一个形状为 (num_embeddings, features)embedding 矩阵。在调用此层时,输入值将用于 0 索引到 embedding 矩阵中。索引大于或等于 num_embeddings 的值会导致 nan 值。当 num_embeddings 等于 1 时,它将使用附加的 features 维度将 embedding 矩阵广播到输入形状。

num_embeddings#

嵌入数量 / 词汇量大小。

features#

每个嵌入的特征维度数量。

dtype#

嵌入向量的 dtype(默认:与嵌入相同)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

embedding_init#

嵌入初始化器。

rngs#

rng 密钥。

__call__(inputs)[source]#

沿着最后一个维度嵌入输入。

参数

inputs – 输入数据,所有维度都被视为批处理维度。输入数组中的值必须是整数。

返回值

输出是嵌入的输入数据。输出形状遵循输入,并附加一个额外的 features 维度。

attend(query)[source]#

使用查询数组关注嵌入。

参数

query – 最后一个维度等于嵌入的特征深度 features 的数组。

返回值

一个最终维度为 num_embeddings 的数组,对应于查询向量数组与每个嵌入的批处理内积。通常用于 NLP 模型中嵌入和 logits 变换之间的权重共享。

方法

attend(query)

使用查询数组关注嵌入。

class flax.nnx.Linear(*args, **kwargs)[source]#

应用于输入最后一个维度的线性变换。

示例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(4,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(3, 4)
  )
})
in_features#

输入特征的数量。

out_features#

输出特征的数量。

use_bias#

是否在输出中添加偏差(默认值:True)。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

kernel_init#

权重矩阵的初始化函数。

bias_init#

偏差的初始化函数。

dot_general#

点积函数。

rngs#

rng 密钥。

__call__(inputs)[source]#

沿着最后一个维度对输入应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回值

变换后的输入。

方法

class flax.nnx.LinearGeneral(*args, **kwargs)[source]#

具有灵活轴的线性变换。

示例用法

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> # apply transformation on the the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
in_features#

输入特征数量的整数或元组。

out_features#

输出特征数量的整数或元组。

axis#

要应用变换的轴的 int 或元组。例如,(-2, -1) 将将变换应用于最后两个轴。

batch_axis#

批处理轴索引到轴大小的映射。

use_bias#

是否在输出中添加偏差(默认值:True)。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

kernel_init#

权重矩阵的初始化函数。

bias_init#

偏差的初始化函数。

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

rngs#

rng 密钥。

__call__(inputs)[source]#

沿着多个维度对输入应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回值

变换后的输入。

方法

class flax.nnx.Einsum(*args, **kwargs)[source]#

具有可学习内核和偏差的 einsum 变换。

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(8, 2, 4)
>>> layer.bias.value.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
einsum_str#

一个字符串,用于表示 einsum 方程。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的 einsum_str 必须且仅有一个为非 None,另一个必须为 None。

kernel_shape#

核的形状。

bias_shape#

偏置的形状。如果为 None,则不会使用偏置。

dtype#

计算的数据类型(默认值:从输入和参数推断)。

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

kernel_init#

权重矩阵的初始化函数。

bias_init#

偏差的初始化函数。

rngs#

rng 密钥。

__call__(inputs, einsum_str=None)[source]#

沿着最后一个维度对输入应用线性变换。

参数
  • inputs – 要变换的 nd 数组。

  • einsum_str – 一个字符串,用于表示 einsum 方程。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的 einsum_str 必须且仅有一个为非 None,另一个必须为 None。

返回值

变换后的输入。

方法