Attention#

class flax.nnx.MultiHeadAttention(*args, **kwargs)[source]#

多头注意力。

示例用法

>>> from flax import nnx
>>> import jax

>>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
...                                decode=False, rngs=nnx.Rngs(0))
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = (
...   jax.random.uniform(key1, shape),
...   jax.random.uniform(key2, shape),
...   jax.random.uniform(key3, shape),
... )

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # equivalent output when inferring v
>>> assert (layer(q, k) == layer(q, k, k)).all()
>>> # equivalent output when inferring k and v
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
num_heads#

注意头的数量。 特征(即 inputs_q.shape[-1])应可被头的数量整除。

in_features#

输入特征数量的整数或元组。

qkv_features#

键、查询和值的维度。

out_features#

最后一个投影的维度

dtype#

计算的 dtype(默认值:从输入和参数推断)

param_dtype#

传递给参数初始化器的 dtype(默认值:float32)

broadcast_dropout#

bool:在批次维度上使用广播 dropout。

dropout_rate#

dropout 率

deterministic#

如果为 false,则注意力权重使用 dropout 随机掩蔽,而如果为 true,则注意力权重是确定的。

precision#

计算的数值精度,详情请参见 jax.lax.Precision

kernel_init#

Dense 层内核的初始化器。

out_kernel_init#

输出 Dense 层内核的可选初始化器,如果为 None,则使用 kernel_init。

bias_init#

Dense 层偏差的初始化器。

out_bias_init#

输出 Dense 层偏差的可选初始化器,如果为 None,则使用 bias_init。

use_bias#

bool:逐点 QKVO 稠密变换是否使用偏差。

attention_fn#

dot_product_attention 或兼容函数。 接收查询、键、值,并返回形状为 [bs, dim1, dim2, …, dimN,, num_heads, value_channels]` 的输出

decode#

是否准备和使用自回归缓存。

normalize_qk#

是否应用 QK 归一化(arxiv.org/abs/2302.05442)。

rngs#

rng 键。

__call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[source]#

对输入数据应用多头点积注意力。

将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。

如果 inputs_k 和 inputs_v 都为 None,则它们都将复制 inputs_q 的值(自注意力)。 如果只有 inputs_v 为 None,则它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes…, length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes…, length, features] 的键。 如果为 None,则 inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes…, length, features] 的值。 如果为 None,则 inputs_v 将复制 inputs_k 的值。

  • mask – 形状为 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力掩码。 如果对应掩码值为 False,则注意力权重将被掩蔽。

  • deterministic – 如果为 false,则注意力权重使用 dropout 随机掩蔽,而如果为 true,则注意力权重是确定的。 传递给调用方法的 deterministic 标志将优先于传递给构造函数的 deterministic 标志。

  • rngs – rng 键。 传递给调用方法的 rng 键将优先于传递给构造函数的 rng 键。

  • sow_weights – 如果为 True,则注意力权重将被播种到 ‘intermediates’ 集合中。

  • decode – 是否准备和使用自回归缓存。 传递给调用方法的 decode 标志将优先于传递给构造函数的 decode 标志。

返回

形状为 [batch_sizes…, length, features] 的输出。

init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[source]#

初始化快速自回归解码的缓存。 当 decode=True 时,必须先调用此方法才能执行正向推断。

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> rngs = nnx.Rngs(42)
...
>>> x = jnp.ones((1, 3))
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=rngs,
... )
...
>>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

方法

init_cache(input_shape[, dtype])

初始化快速自回归解码的缓存。

flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[source]#

组合注意力掩码。

参数
  • *masks – 要组合的注意力掩码参数集,其中一些可以为 None。

  • dtype – 返回掩码的 dtype。

返回

组合的掩码,通过逻辑与进行缩减,如果没有给出掩码,则返回 None。

flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[source]#

计算给定查询、键和值的点积注意力。

这是基于 https://arxiv.org/abs/1706.03762 的注意力机制的核心函数。它根据查询和键计算注意力权重,并使用注意力权重组合值。

注意

querykeyvalue 不需要任何批次维度。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • value – 用于注意力的值,形状为 [batch..., kv_length, num_heads, v_depth_per_head]

  • bias – 注意力权重的偏差。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码、填充掩码、邻近偏差等。

  • mask – 注意力权重的掩码。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码。如果相应的掩码值为 False,则注意力权重将被屏蔽。

  • broadcast_dropout – bool: 在批次维度上使用广播 dropout。

  • dropout_rng – JAX PRNGKey: 用于 dropout

  • dropout_rate – dropout 率

  • deterministic – bool,确定性或否(是否应用 dropout)

  • dtype – 计算的数据类型(默认:从输入中推断)

  • precision – 计算的数值精度,详见 jax.lax.Precision

  • module – 将注意力权重播种到 nnx.Intermediate 集合中的模块。如果 module 为 None,则不会播种注意力权重。

返回

输出形状为 [batch…, q_length, num_heads, v_depth_per_head]

flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

注意力权重的掩码生成辅助函数。

如果输入为 1d(即,[batch…, len_q][batch…, len_kv]),则注意力权重将为 [batch…, heads, len_q, len_kv],此函数将生成 [batch…, 1, len_q, len_kv]

参数
  • query_input – 查询长度大小的批次扁平输入

  • key_input – 键长度大小的批次扁平输入

  • pairwise_fn – 广播逐元素比较函数

  • extra_batch_dims – 要添加的额外批次维度的数量,默认为 0

  • dtype – 掩码返回数据类型

返回

一个形状为 [batch…, 1, len_q, len_kv] 的 1d 注意力掩码。

flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

为自注意力创建因果掩码。

如果输入为 1d(即,[batch…, len]),则自注意力权重将为 [batch…, heads, len, len],此函数将生成形状为 [batch…, 1, len, len] 的因果掩码。

参数
  • x – 形状为 [batch…, len] 的输入数组

  • extra_batch_dims – 要添加的额外批次维度的数量,默认为 0

  • dtype – 掩码返回数据类型

返回

一个形状为 [batch…, 1, len, len] 的 1d 注意力因果掩码。