Attention¶
Attention - 23¶
Version¶
name: Attention (GitHub)
domain: main
since_version: 23
function: True
support_level: SupportType.COMMON
shape inference: True
This version of the operator has been available since version 23.
Summary¶
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed.
This operator covers self and cross variants of the attention operation based on sequence lengths of K, Q and V.
For self attention, kv_sequence_length equals q_sequence_length. For cross attention, query and key might have different lengths.
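As a concrete illustration of the two cases, the NumPy sketch below builds inputs for self and cross attention; all dimension values are arbitrary placeholders.

```python
import numpy as np

batch, num_heads, head_size = 2, 4, 8

# Self attention: K and V come from the same sequence as Q,
# so kv_sequence_length == q_sequence_length.
q_len = kv_len = 16
Q = np.random.randn(batch, num_heads, q_len, head_size).astype(np.float32)
K = np.random.randn(batch, num_heads, kv_len, head_size).astype(np.float32)
V = np.random.randn(batch, num_heads, kv_len, head_size).astype(np.float32)

# Cross attention: Q comes from one sequence (e.g. a decoder), while K and V
# come from another (e.g. an encoder), so the lengths may differ.
q_len, kv_len = 16, 48
Q = np.random.randn(batch, num_heads, q_len, head_size).astype(np.float32)
K = np.random.randn(batch, num_heads, kv_len, head_size).astype(np.float32)
V = np.random.randn(batch, num_heads, kv_len, head_size).astype(np.float32)
```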
This operator also covers the following 3 variants based on the number of heads:

- Multi-headed Attention (MHA): described in the paper https://arxiv.org/pdf/1706.03762; q_num_heads = kv_num_heads.
- Group-query Attention (GQA): described in the paper https://arxiv.org/pdf/2305.13245; q_num_heads > kv_num_heads, q_num_heads % kv_num_heads == 0 (see the head-grouping sketch after this list).
- Multi-query Attention (MQA): described in the paper https://arxiv.org/pdf/1911.02150; q_num_heads > kv_num_heads, kv_num_heads = 1.
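For GQA and MQA, each key/value head serves a group of query heads. A minimal NumPy sketch of that grouping, assuming the common approach of repeating the K/V heads to match q_num_heads (an implementation may realize the grouping differently):

```python
import numpy as np

batch, q_num_heads, kv_num_heads = 1, 8, 2   # GQA: 8 % 2 == 0; MQA would use kv_num_heads = 1
seq_len, head_size = 16, 64

Q = np.random.randn(batch, q_num_heads, seq_len, head_size).astype(np.float32)
K = np.random.randn(batch, kv_num_heads, seq_len, head_size).astype(np.float32)
V = np.random.randn(batch, kv_num_heads, seq_len, head_size).astype(np.float32)

# Each group of q_num_heads // kv_num_heads query heads attends to one K/V head.
group_size = q_num_heads // kv_num_heads
K_expanded = np.repeat(K, group_size, axis=1)   # (batch, q_num_heads, seq_len, head_size)
V_expanded = np.repeat(V, group_size, axis=1)
assert K_expanded.shape[1] == Q.shape[1]
```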
The attention bias to be added is computed from the attn_mask input and the is_causal attribute; only one of the two may be provided. If is_causal is set to 1, the attention masking is a lower triangular matrix when the mask is a square matrix; due to the alignment, the masking has the form of the upper-left causal bias. attn_mask: a boolean mask where a value of True indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.
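A short sketch of how the two mask flavours can act as an additive bias on the attention scores; using -inf for masked-out positions is an assumption consistent with the softmax that follows, not a quote of the reference implementation.

```python
import numpy as np

q_len, total_len = 4, 6
scores = np.random.randn(q_len, total_len).astype(np.float32)  # Q @ K^T, already scaled

# Boolean mask: True means "this position participates in attention".
bool_mask = np.tril(np.ones((q_len, total_len), dtype=bool))
bias_from_bool = np.where(bool_mask, 0.0, -np.inf).astype(np.float32)

# Float mask: added to the attention scores as-is.
float_mask = np.zeros((q_len, total_len), dtype=np.float32)

scores_with_bias = scores + bias_from_bool   # or: scores + float_mask
```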
Both past and present state key/values are optional. They must be used together; providing only one of them is not allowed. The following pattern is applied to the Q, K and V inputs after appropriate reshaping of the K and V inputs based on the sequence lengths and number of heads provided:
```
          Q           K          V
          |           |          |
   Q*sqrt(scale) K*sqrt(scale)   |
          |           |          |
          |       Transpose      |
          |           |          |
          ----MatMul----         |
                |                |
     at_mask---Add               |
                |                |
      softcap (if provided)      |
                |                |
             Softmax             |
                |                |
          ------MatMul------
                |
                Y
```
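The diagram above can be written out as a small NumPy reference for 4D inputs, without the optional cache or the per-mode qk_matmul_output. The sqrt(scale) factoring, the softcap form softcap * tanh(x / softcap) and the -inf handling of boolean masks are assumptions about the reference semantics, not normative statements.

```python
import numpy as np

def attention_sketch(Q, K, V, attn_mask=None, scale=None, softcap=0.0):
    """Scaled dot-product attention over 4D (batch, heads, seq, head_size) inputs."""
    head_size = Q.shape[-1]
    if scale is None:
        scale = 1.0 / np.sqrt(head_size)           # default: 1/sqrt(head_size)

    # Scale Q and K by sqrt(scale) so their product is scaled by `scale`.
    q = Q * np.sqrt(scale)
    k = K * np.sqrt(scale)

    scores = q @ k.transpose(0, 1, 3, 2)           # (batch, heads, q_len, total_len)

    if attn_mask is not None:
        if attn_mask.dtype == np.bool_:            # boolean: True means "attend"
            scores = np.where(attn_mask, scores, -np.inf)
        else:                                      # float: added to the scores
            scores = scores + attn_mask

    if softcap > 0.0:                              # assumed soft-capping form
        scores = softcap * np.tanh(scores / softcap)

    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)   # softmax

    return weights @ V                             # (batch, heads, q_len, v_head_size)

B, N, L, H = 1, 4, 8, 16
Y = attention_sketch(np.random.randn(B, N, L, H).astype(np.float32),
                     np.random.randn(B, N, L, H).astype(np.float32),
                     np.random.randn(B, N, L, H).astype(np.float32))
```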
Attributes¶
- is_causal - INT (default is '0'): If set to 1, the attention masking is a lower triangular matrix when the mask is a square matrix; due to the alignment, the masking has the form of the upper-left causal bias (see the sketch after this list).
- kv_num_heads - INT: Number of heads of key and value. Must be used with 3D inputs of Q, K and V.
- q_num_heads - INT: Number of heads of query. Must be used with 3D inputs of Q, K and V.
- qk_matmul_output_mode - INT (default is '0'): If set to 0, qk_matmul_output is the output of the QK matmul. If set to 1, qk_matmul_output includes the addition of the attention mask to the output of the QK matmul. If set to 2, qk_matmul_output is the output after the softcap operation. If set to 3, qk_matmul_output is the output after the softmax operation.
- scale - FLOAT: Scaling factor applied to Q and K before the matmul for numerical stability (see https://tinyurl.com/sudb9s96 for the math). Default value is 1/sqrt(head_size).
- softcap - FLOAT (default is '0.0'): Softcap value for the attention weights.
- softmax_precision - INT: The floating-point precision used in the softmax computation. If not provided, the same precision as the softmax input (Q and K) is used.
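To illustrate is_causal, the sketch below builds an upper-left-aligned causal boolean mask of shape (q_sequence_length, total_sequence_length); expressing the alignment with np.tril is an assumption consistent with the description above.

```python
import numpy as np

q_sequence_length, total_sequence_length = 3, 5

# Upper-left aligned causal mask: query position i may attend to position j only when j <= i.
causal_mask = np.tril(np.ones((q_sequence_length, total_sequence_length), dtype=bool))
# array([[ True, False, False, False, False],
#        [ True,  True, False, False, False],
#        [ True,  True,  True, False, False]])
```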
Inputs¶
Between 3 and 6 inputs.
- Q (heterogeneous) - T1: Query tensor. 4D tensor with shape (batch_size, q_num_heads, q_sequence_length, head_size), or 3D tensor with shape (batch_size, q_sequence_length, q_hidden_size). For the 3D case, q_hidden_size = q_num_heads * head_size (the 3D/4D equivalence is sketched after this list).
- K (heterogeneous) - T1: Key tensor. 4D tensor with shape (batch_size, kv_num_heads, kv_sequence_length, head_size), or 3D tensor with shape (batch_size, kv_sequence_length, k_hidden_size). For the 3D case, k_hidden_size = kv_num_heads * head_size.
- V (heterogeneous) - T2: Value tensor. 4D tensor with shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size), or 3D tensor with shape (batch_size, kv_sequence_length, v_hidden_size). For the 3D case, v_hidden_size = kv_num_heads * v_head_size.
- attn_mask (optional, heterogeneous) - U: Attention mask. Shape must be broadcastable to a 4D tensor with shape (batch_size, q_num_heads, q_sequence_length, total_sequence_length), where total_sequence_length = past_sequence_length + kv_sequence_length. Two types of masks are supported: a boolean mask where a value of True indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.
- past_key (optional, heterogeneous) - T1: Past state cache for key, with shape (batch_size, kv_num_heads, past_sequence_length, head_size).
- past_value (optional, heterogeneous) - T2: Past state cache for value, with shape (batch_size, kv_num_heads, past_sequence_length, v_head_size).
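The relation q_hidden_size = q_num_heads * head_size means a 3D input can be viewed as the 4D layout via a reshape and transpose. A sketch of that equivalence, assuming the conventional (batch, seq, heads, head_size) intermediate ordering; the operator's internal layout is not specified here.

```python
import numpy as np

batch_size, q_sequence_length = 2, 16
q_num_heads, head_size = 8, 64
q_hidden_size = q_num_heads * head_size

Q_3d = np.random.randn(batch_size, q_sequence_length, q_hidden_size).astype(np.float32)

# (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
Q_4d = Q_3d.reshape(batch_size, q_sequence_length, q_num_heads, head_size)
Q_4d = Q_4d.transpose(0, 2, 1, 3)
assert Q_4d.shape == (batch_size, q_num_heads, q_sequence_length, head_size)
```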
Outputs¶
Between 1 and 4 outputs.
- Y (heterogeneous) - T1: The output tensor. 4D tensor with shape (batch_size, q_num_heads, q_sequence_length, v_head_size), or 3D tensor with shape (batch_size, q_sequence_length, hidden_size). For cases with a 3D input tensor, hidden_size = q_num_heads * v_head_size.
- present_key (optional, heterogeneous) - T1: Updated key cache with shape (batch_size, kv_num_heads, total_sequence_length, head_size), where total_sequence_length = past_sequence_length + kv_sequence_length (see the concatenation sketch after this list).
- present_value (optional, heterogeneous) - T2: Updated value cache with shape (batch_size, kv_num_heads, total_sequence_length, v_head_size), where total_sequence_length = past_sequence_length + kv_sequence_length.
- qk_matmul_output (optional, heterogeneous) - T1: The output of the QK matmul. 4D tensor with shape (batch_size, q_num_heads, q_sequence_length, total_sequence_length), where total_sequence_length = past_sequence_length + kv_sequence_length.
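A sketch of how past_key/past_value relate to present_key/present_value by concatenation along the sequence axis; all shape values are placeholders.

```python
import numpy as np

batch_size, kv_num_heads, head_size = 1, 2, 64
past_sequence_length, kv_sequence_length = 12, 1   # e.g. one new token per decode step

past_key = np.random.randn(batch_size, kv_num_heads,
                           past_sequence_length, head_size).astype(np.float32)
K = np.random.randn(batch_size, kv_num_heads,
                    kv_sequence_length, head_size).astype(np.float32)

# present_key covers total_sequence_length = past_sequence_length + kv_sequence_length
present_key = np.concatenate([past_key, K], axis=2)
assert present_key.shape == (batch_size, kv_num_heads,
                             past_sequence_length + kv_sequence_length, head_size)
```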
Type Constraints¶
- T1 in ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ): Constrain Q and K input types to float tensors.
- T2 in ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ): Constrain V input types to float tensors.
- U in ( tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) ): Constrain the 'mask' input types to boolean tensors and the input types.
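A minimal sketch of wiring up an Attention node with onnx.helper, using only the required Q, K, V inputs and the is_causal attribute. The tensor names and dimension values are illustrative placeholders, and running the checker assumes an onnx release recent enough to include opset 23.

```python
import onnx
from onnx import TensorProto, helper

batch, q_heads, kv_heads, q_len, kv_len, head = 1, 8, 2, 4, 4, 64

# GQA configuration: q_heads > kv_heads and q_heads % kv_heads == 0.
node = helper.make_node(
    "Attention",
    inputs=["Q", "K", "V"],          # attn_mask / past_key / past_value are optional
    outputs=["Y"],
    is_causal=1,                     # lower-triangular (upper-left aligned) masking
)

graph = helper.make_graph(
    [node],
    "attention_example",
    inputs=[
        helper.make_tensor_value_info("Q", TensorProto.FLOAT, [batch, q_heads, q_len, head]),
        helper.make_tensor_value_info("K", TensorProto.FLOAT, [batch, kv_heads, kv_len, head]),
        helper.make_tensor_value_info("V", TensorProto.FLOAT, [batch, kv_heads, kv_len, head]),
    ],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [batch, q_heads, q_len, head]),
    ],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])
onnx.checker.check_model(model)
```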