Attention#
- class flax.nnx.MultiHeadAttention(*args, **kwargs)[source]#
多头注意力。
示例用法
>>> from flax import nnx >>> import jax >>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16, ... decode=False, rngs=nnx.Rngs(0)) >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> shape = (4, 3, 2, 5) >>> q, k, v = ( ... jax.random.uniform(key1, shape), ... jax.random.uniform(key2, shape), ... jax.random.uniform(key3, shape), ... ) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer(q, k, v) >>> # equivalent output when inferring v >>> assert (layer(q, k) == layer(q, k, k)).all() >>> # equivalent output when inferring k and v >>> assert (layer(q) == layer(q, q)).all() >>> assert (layer(q) == layer(q, q, q)).all()
- num_heads#
注意头的数量。 特征(即 inputs_q.shape[-1])应可被头的数量整除。
- in_features#
输入特征数量的整数或元组。
- qkv_features#
键、查询和值的维度。
- out_features#
最后一个投影的维度
- dtype#
计算的 dtype(默认值:从输入和参数推断)
- param_dtype#
传递给参数初始化器的 dtype(默认值:float32)
- broadcast_dropout#
bool:在批次维度上使用广播 dropout。
- dropout_rate#
dropout 率
- deterministic#
如果为 false,则注意力权重使用 dropout 随机掩蔽,而如果为 true,则注意力权重是确定的。
- precision#
计算的数值精度,详情请参见 jax.lax.Precision。
- kernel_init#
Dense 层内核的初始化器。
- out_kernel_init#
输出 Dense 层内核的可选初始化器,如果为 None,则使用 kernel_init。
- bias_init#
Dense 层偏差的初始化器。
- out_bias_init#
输出 Dense 层偏差的可选初始化器,如果为 None,则使用 bias_init。
- use_bias#
bool:逐点 QKVO 稠密变换是否使用偏差。
- attention_fn#
dot_product_attention 或兼容函数。 接收查询、键、值,并返回形状为 [bs, dim1, dim2, …, dimN,, num_heads, value_channels]` 的输出
- decode#
是否准备和使用自回归缓存。
- normalize_qk#
是否应用 QK 归一化(arxiv.org/abs/2302.05442)。
- rngs#
rng 键。
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[source]#
对输入数据应用多头点积注意力。
将输入投影到多头查询、键和值向量,应用点积注意力,并将结果投影到输出向量。
如果 inputs_k 和 inputs_v 都为 None,则它们都将复制 inputs_q 的值(自注意力)。 如果只有 inputs_v 为 None,则它将复制 inputs_k 的值。
- 参数
inputs_q – 形状为 [batch_sizes…, length, features] 的输入查询。
inputs_k – 形状为 [batch_sizes…, length, features] 的键。 如果为 None,则 inputs_k 将复制 inputs_q 的值。
inputs_v – 形状为 [batch_sizes…, length, features] 的值。 如果为 None,则 inputs_v 将复制 inputs_k 的值。
mask – 形状为 [batch_sizes…, num_heads, query_length, key/value_length] 的注意力掩码。 如果对应掩码值为 False,则注意力权重将被掩蔽。
deterministic – 如果为 false,则注意力权重使用 dropout 随机掩蔽,而如果为 true,则注意力权重是确定的。 传递给调用方法的
deterministic
标志将优先于传递给构造函数的deterministic
标志。rngs – rng 键。 传递给调用方法的 rng 键将优先于传递给构造函数的 rng 键。
sow_weights – 如果为
True
,则注意力权重将被播种到 ‘intermediates’ 集合中。decode – 是否准备和使用自回归缓存。 传递给调用方法的
decode
标志将优先于传递给构造函数的decode
标志。
- 返回
形状为 [batch_sizes…, length, features] 的输出。
- init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[source]#
初始化快速自回归解码的缓存。 当
decode=True
时,必须先调用此方法才能执行正向推断。示例用法
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> rngs = nnx.Rngs(42) ... >>> x = jnp.ones((1, 3)) >>> model_nnx = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, ... rngs=rngs, ... ) ... >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized ... >>> model_nnx.init_cache(x.shape) >>> out_nnx = model_nnx(x)
方法
init_cache
(input_shape[, dtype])初始化快速自回归解码的缓存。
- flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[source]#
组合注意力掩码。
- 参数
*masks – 要组合的注意力掩码参数集,其中一些可以为 None。
dtype – 返回掩码的 dtype。
- 返回
组合的掩码,通过逻辑与进行缩减,如果没有给出掩码,则返回 None。
- flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[source]#
计算给定查询、键和值的点积注意力。
这是基于 https://arxiv.org/abs/1706.03762 的注意力机制的核心函数。它根据查询和键计算注意力权重,并使用注意力权重组合值。
注意
query
、key
、value
不需要任何批次维度。- 参数
query – 用于计算注意力的查询,形状为
[batch..., q_length, num_heads, qk_depth_per_head]
。key – 用于计算注意力的键,形状为
[batch..., kv_length, num_heads, qk_depth_per_head]
。value – 用于注意力的值,形状为
[batch..., kv_length, num_heads, v_depth_per_head]
。bias – 注意力权重的偏差。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码、填充掩码、邻近偏差等。
mask – 注意力权重的掩码。它应该可以广播到形状 [batch…, num_heads, q_length, kv_length]。这可以用于合并因果掩码。如果相应的掩码值为 False,则注意力权重将被屏蔽。
broadcast_dropout – bool: 在批次维度上使用广播 dropout。
dropout_rng – JAX PRNGKey: 用于 dropout
dropout_rate – dropout 率
deterministic – bool,确定性或否(是否应用 dropout)
dtype – 计算的数据类型(默认:从输入中推断)
precision – 计算的数值精度,详见 jax.lax.Precision。
module – 将注意力权重播种到
nnx.Intermediate
集合中的模块。如果module
为 None,则不会播种注意力权重。
- 返回
输出形状为 [batch…, q_length, num_heads, v_depth_per_head]。
- flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
注意力权重的掩码生成辅助函数。
如果输入为 1d(即,[batch…, len_q],[batch…, len_kv]),则注意力权重将为 [batch…, heads, len_q, len_kv],此函数将生成 [batch…, 1, len_q, len_kv]。
- 参数
query_input – 查询长度大小的批次扁平输入
key_input – 键长度大小的批次扁平输入
pairwise_fn – 广播逐元素比较函数
extra_batch_dims – 要添加的额外批次维度的数量,默认为 0
dtype – 掩码返回数据类型
- 返回
一个形状为 [batch…, 1, len_q, len_kv] 的 1d 注意力掩码。
- flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
为自注意力创建因果掩码。
如果输入为 1d(即,[batch…, len]),则自注意力权重将为 [batch…, heads, len, len],此函数将生成形状为 [batch…, 1, len, len] 的因果掩码。
- 参数
x – 形状为 [batch…, len] 的输入数组
extra_batch_dims – 要添加的额外批次维度的数量,默认为 0
dtype – 掩码返回数据类型
- 返回
一个形状为 [batch…, 1, len, len] 的 1d 注意力因果掩码。