Linear#
NNX 线性层类。
- class flax.nnx.Conv(*args, **kwargs)[source]#
卷积模块,包装了
lax.conv_general_dilated
。示例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), ... strides=2, padding='CIRCULAR', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- in_features#
输入特征数量的整数或元组。
- out_features#
输出特征数量的整数或元组。
- kernel_size#
卷积核的形状。对于一维卷积,核大小可以作为整数传递,这将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。
- strides#
一个整数或
n
个整数的序列,代表窗口之间的步长(默认值:1)。
- padding#
字符串
'SAME'
、字符串'VALID'
、字符串'CIRCULAR'
(周期性边界条件)或n
个(low, high)
整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个 int 被解释为在所有维度上应用相同的填充,在序列中传递单个 int 会导致两侧使用相同的填充。'CAUSAL'
填充用于一维卷积将左填充卷积轴,从而导致相同大小的输出。
- input_dilation#
一个整数或
n
个整数的序列,指定在inputs
的每个空间维度中应用的膨胀因子(默认值:1)。使用输入膨胀因子d
的卷积等效于步长为d
的转置卷积。
- kernel_dilation#
一个整数或
n
个整数的序列,指定在卷积核的每个空间维度中应用的膨胀因子(默认值:1)。使用核膨胀因子的卷积也被称为“空洞卷积”。
- feature_group_count#
整数,默认值 1。如果指定,则将输入特征分成组。
- use_bias#
是否在输出中添加偏差(默认值:True)。
- mask#
在掩码卷积期间对权重进行掩码的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。
- kernel_init#
卷积核的初始化器。
- bias_init#
偏差的初始化器。
- rngs#
rng 密钥。
- __call__(inputs)[source]#
将(可能未共享)卷积应用于输入。
- 参数
inputs – 输入数据,维度为
(*batch_dims, spatial_dims..., features)
。这是 channels-last 约定,例如,对于二维卷积,为 NHWC,对于三维卷积,为 NDHWC。注意:这与lax.conv_general_dilated
使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度都将被展平成一个维度进行卷积,并在返回之前恢复。在某些情况下,直接对层进行 vmap 可能比这种默认展平方法具有更好的性能。如果输入缺少批次维度,则会为卷积添加批次维度并在返回时删除,这种做法是为了能够编写单个示例代码。- 返回值
卷积后的数据。
方法
- class flax.nnx.ConvTranspose(*args, **kwargs)[source]#
卷积模块,包装了
lax.conv_transpose
。示例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.value.shape (3, 3, 4) >>> layer.bias.value.shape (4,) >>> out = layer(x) >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6), ... strides=(2, 2), padding='CIRCULAR', ... transpose_kernel=True, rngs=rngs) >>> layer.kernel.value.shape (6, 6, 4, 3) >>> layer.bias.value.shape (4,) >>> out = layer(jnp.ones((1, 15, 15, 3))) >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x)
- in_features#
输入特征数量的整数或元组。
- out_features#
输出特征数量的整数或元组。
- kernel_size#
卷积核的形状。对于一维卷积,核大小可以作为整数传递,这将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。
- strides#
一个整数或
n
个整数的序列,代表窗口之间的步长(默认值:1)。
- padding#
字符串
'SAME'
、字符串'VALID'
、字符串'CIRCULAR'
(周期性边界条件)或n
个(low, high)
整数对的序列,用于指定在每个空间维度之前和之后应用的填充。单个 int 被解释为在所有维度上应用相同的填充,在序列中传递单个 int 会导致两侧使用相同的填充。'CAUSAL'
填充用于一维卷积将左填充卷积轴,从而导致相同大小的输出。
- kernel_dilation#
一个整数或
n
个整数的序列,指定在卷积核的每个空间维度中应用的膨胀因子(默认值:1)。使用核膨胀因子的卷积也被称为“空洞卷积”。
- use_bias#
是否在输出中添加偏差(默认值:True)。
- mask#
在掩码卷积期间对权重进行掩码的可选掩码。掩码必须与卷积权重矩阵具有相同的形状。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。
- kernel_init#
卷积核的初始化器。
- bias_init#
偏差的初始化器。
- transpose_kernel#
如果
True
则翻转空间轴并交换核的输入/输出通道轴。
- rngs#
rng 密钥。
- __call__(inputs)[source]#
将转置卷积应用于输入。
行为与
jax.lax.conv_transpose
相似。- 参数
inputs – 输入数据,维度为
(*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by ``lax.conv_general_dilated
, which puts the spatial dimensions last. Note: 如果输入具有超过一个批处理维度,则所有批处理维度将被展平为单个维度以进行卷积,并在返回之前恢复。在某些情况下,直接将层 vmap'ing 可能比这种默认展平方法产生更好的性能。如果输入缺少批处理维度,则将为卷积添加它,并在返回时删除它,允许编写单个示例代码。- 返回值
卷积后的数据。
方法
- class flax.nnx.Embed(*args, **kwargs)[source]#
嵌入模块。
示例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'embedding': VariableState( type=Param, value=Array([[-0.90411377, -0.3648777 , -1.1083648 ], [ 0.01070483, 0.27923733, 1.7487359 ], [ 0.59161806, 0.8660184 , 1.2838588 ], [-0.748139 , -0.15856352, 0.06061118], [-0.4769059 , -0.6607095 , 0.46697947]], dtype=float32) ) }) >>> # get the first three and last three embeddings >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> layer(indices_input) Array([[[-0.90411377, -0.3648777 , -1.1083648 ], [ 0.01070483, 0.27923733, 1.7487359 ], [ 0.59161806, 0.8660184 , 1.2838588 ]], [[-0.4769059 , -0.6607095 , 0.46697947], [-0.748139 , -0.15856352, 0.06061118], [ 0.59161806, 0.8660184 , 1.2838588 ]]], dtype=float32)
一个从整数 [0,
num_embeddings
) 到features
维向量参数化函数。此Module
将创建一个形状为(num_embeddings, features)
的embedding
矩阵。在调用此层时,输入值将用于 0 索引到embedding
矩阵中。索引大于或等于num_embeddings
的值会导致nan
值。当num_embeddings
等于 1 时,它将使用附加的features
维度将embedding
矩阵广播到输入形状。- num_embeddings#
嵌入数量 / 词汇量大小。
- features#
每个嵌入的特征维度数量。
- dtype#
嵌入向量的 dtype(默认:与嵌入相同)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- embedding_init#
嵌入初始化器。
- rngs#
rng 密钥。
- __call__(inputs)[source]#
沿着最后一个维度嵌入输入。
- 参数
inputs – 输入数据,所有维度都被视为批处理维度。输入数组中的值必须是整数。
- 返回值
输出是嵌入的输入数据。输出形状遵循输入,并附加一个额外的
features
维度。
- attend(query)[source]#
使用查询数组关注嵌入。
- 参数
query – 最后一个维度等于嵌入的特征深度
features
的数组。- 返回值
一个最终维度为
num_embeddings
的数组,对应于查询向量数组与每个嵌入的批处理内积。通常用于 NLP 模型中嵌入和 logits 变换之间的权重共享。
方法
attend
(query)使用查询数组关注嵌入。
- class flax.nnx.Linear(*args, **kwargs)[source]#
应用于输入最后一个维度的线性变换。
示例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': VariableState( type=Param, value=(4,) ), 'kernel': VariableState( type=Param, value=(3, 4) ) })
- in_features#
输入特征的数量。
- out_features#
输出特征的数量。
- use_bias#
是否在输出中添加偏差(默认值:True)。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。
- kernel_init#
权重矩阵的初始化函数。
- bias_init#
偏差的初始化函数。
- dot_general#
点积函数。
- rngs#
rng 密钥。
方法
- class flax.nnx.LinearGeneral(*args, **kwargs)[source]#
具有灵活轴的线性变换。
示例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` >>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4) >>> # output features (4, 5) >>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 4, 5) >>> layer.bias.value.shape (4, 5) >>> # apply transformation on the the second and last axes >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (2, 3, 4, 5) >>> layer.bias.value.shape (4, 5) >>> y = layer(jnp.ones((16, 2, 3))) >>> y.shape (16, 4, 5)
- in_features#
输入特征数量的整数或元组。
- out_features#
输出特征数量的整数或元组。
- axis#
要应用变换的轴的 int 或元组。例如,(-2, -1) 将将变换应用于最后两个轴。
- batch_axis#
批处理轴索引到轴大小的映射。
- use_bias#
是否在输出中添加偏差(默认值:True)。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- kernel_init#
权重矩阵的初始化函数。
- bias_init#
偏差的初始化函数。
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。
- rngs#
rng 密钥。
方法
- class flax.nnx.Einsum(*args, **kwargs)[source]#
具有可学习内核和偏差的 einsum 变换。
示例用法
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) >>> layer.kernel.value.shape (8, 2, 4) >>> layer.bias.value.shape (8, 4) >>> y = layer(jnp.ones((16, 11, 2))) >>> y.shape (16, 11, 8, 4)
- einsum_str#
一个字符串,用于表示 einsum 方程。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的
einsum_str
必须且仅有一个为非 None,另一个必须为 None。
- kernel_shape#
核的形状。
- bias_shape#
偏置的形状。如果为 None,则不会使用偏置。
- dtype#
计算的数据类型(默认值:从输入和参数推断)。
- param_dtype#
传递给参数初始化器的数据类型(默认值:float32)。
- precision#
计算的数值精度,有关详细信息,请参阅
jax.lax.Precision
。
- kernel_init#
权重矩阵的初始化函数。
- bias_init#
偏差的初始化函数。
- rngs#
rng 密钥。
- __call__(inputs, einsum_str=None)[source]#
沿着最后一个维度对输入应用线性变换。
- 参数
inputs – 要变换的 nd 数组。
einsum_str – 一个字符串,用于表示 einsum 方程。该方程必须恰好有两个操作数,左侧是传入的输入,右侧是可学习的核。构造函数参数和调用参数中的
einsum_str
必须且仅有一个为非 None,另一个必须为 None。
- 返回值
变换后的输入。
方法