(l-onnx-doc-Attention)=

# Attention

(l-onnx-op-attention-23)=

## Attention - 23

### Version

- **name**: [Attention (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Attention)
- **domain**: `main`
- **since_version**: `23`
- **function**: `True`
- **support_level**: `SupportType.COMMON`
- **shape inference**: `True`

This version of the operator has been available **since version 23**.

### Summary

Computes scaled dot-product attention on query, key and value tensors, using an optional attention mask if passed.

This operator covers the self and cross variants of the attention operation based on the sequence lengths of K, Q and V. For self attention, `kv_sequence_length` equals `q_sequence_length`. For cross attention, query and key might have different lengths.

This operator also covers the following 3 variants based on the number of heads:

1) Multi-headed Attention (MHA): Described in the paper https://arxiv.org/pdf/1706.03762, `q_num_heads = kv_num_heads`.
2) Group-query Attention (GQA): Described in the paper https://arxiv.org/pdf/2305.13245, `q_num_heads > kv_num_heads`, `q_num_heads % kv_num_heads == 0`.
3) Multi-query Attention (MQA): Described in the paper https://arxiv.org/pdf/1911.02150, `q_num_heads > kv_num_heads`, `kv_num_heads = 1`.

The attention bias to be added is computed from the `attn_mask` input and the `is_causal` attribute, only one of which can be provided.

1) If `is_causal` is set to `1`, the attention masking is a lower-triangular matrix when the mask is a square matrix. The attention masking has the form of the upper-left causal bias due to the alignment.
2) `attn_mask`: a boolean mask where a value of `True` indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.

Both past and present state key/values are optional. They shall be used together; providing only one of them is not allowed.

The following pattern is applied to the Q, K and V inputs after appropriate reshaping of the K and V inputs based on the sequence lengths and number of heads provided:

```
The following pattern is applied by this operator:
      Q          K          V
      |          |          |
Q*sqrt(scale) K*sqrt(scale) |
      |          |          |
      |       Transpose     |
      |          |          |
      ---MatMul---          |
            |               |
 attn_mask---Add            |
            |               |
  softcap (if provided)     |
            |               |
         Softmax            |
            |               |
            -----MatMul------
                  |
                  Y
```
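To make the dataflow above concrete, here is a minimal NumPy sketch of the same pattern, assuming 4D inputs with `q_num_heads == kv_num_heads`, no past/present cache, and the default softmax precision; the function name `attention_reference` and its argument defaults are illustrative choices, not part of the specification:

```python
import numpy as np


def attention_reference(Q, K, V, attn_mask=None, scale=None, softcap=0.0, is_causal=False):
    """Illustrative sketch for 4D inputs shaped (batch_size, num_heads, seq_len, head_size)."""
    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])                 # default: 1/sqrt(head_size)
    # Scale Q and K by sqrt(scale) before the matmul, then compute Q @ K^T.
    qk = (Q * np.sqrt(scale)) @ (K * np.sqrt(scale)).transpose(0, 1, 3, 2)
    # The attention bias comes from either is_causal or attn_mask, never both.
    if is_causal:
        causal = np.tril(np.ones(qk.shape[-2:], dtype=bool))  # upper-left aligned lower triangle
        qk = np.where(causal, qk, -np.inf)
    elif attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            qk = np.where(attn_mask, qk, -np.inf)          # True means "participates in attention"
        else:
            qk = qk + attn_mask                            # float mask is added to the scores
    if softcap > 0.0:
        qk = softcap * np.tanh(qk / softcap)               # optional soft capping of the scores
    # Numerically stable softmax over the key axis, then weight the values.
    w = np.exp(qk - qk.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ V
```

GQA/MQA head broadcasting, the KV cache inputs, and the `qk_matmul_output_mode` and `softmax_precision` attributes are intentionally left out of this sketch; the attributes, inputs and outputs below describe the full behaviour.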
### Attributes

* **is_causal - INT** (default is `'0'`): If set to `1`, the attention masking is a lower-triangular matrix when the mask is a square matrix. The attention masking has the form of the upper-left causal bias due to the alignment.
* **kv_num_heads - INT**: Number of heads of key and value. Must be used with 3D inputs of Q, K and V.
* **q_num_heads - INT**: Number of heads of query. Must be used with 3D inputs of Q, K and V.
* **qk_matmul_output_mode - INT** (default is `'0'`): If set to `0`, `qk_matmul_output` is the output of the QK matmul. If set to `1`, `qk_matmul_output` includes the addition of the attention mask to the output of the QK matmul. If set to `2`, `qk_matmul_output` is the output after the softcap operation. If set to `3`, `qk_matmul_output` is the output after the softmax operation. Default value is `0`.
* **scale - FLOAT**: Scaling factor applied to the attention scores: Q and K are scaled by `sqrt(scale)` before the matmul for numerical stability (see https://tinyurl.com/sudb9s96 for the math). Default value is `1/sqrt(head_size)`.
* **softcap - FLOAT** (default is `'0.0'`): Softcap value for the attention weights. Default value is `0`.
* **softmax_precision - INT**: The floating-point precision used in the softmax computation. If the softmax precision is not provided, the same precision as the input of the softmax (Q and K) is used.

### Inputs

Between 3 and 6 inputs.

- **Q** (heterogeneous) - **T1**: Query tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, head_size)` or 3D tensor with shape `(batch_size, q_sequence_length, q_hidden_size)`. For cases with a 3D input tensor, `q_hidden_size = q_num_heads * head_size`.
- **K** (heterogeneous) - **T1**: Key tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, head_size)` or 3D tensor with shape `(batch_size, kv_sequence_length, k_hidden_size)`. For cases with a 3D input tensor, `k_hidden_size = kv_num_heads * head_size`.
- **V** (heterogeneous) - **T2**: Value tensor. 4D tensor with shape `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)` or 3D tensor with shape `(batch_size, kv_sequence_length, v_hidden_size)`. For cases with a 3D input tensor, `v_hidden_size = kv_num_heads * v_head_size`.
- **attn_mask** (optional, heterogeneous) - **U**: Attention mask. Shape must be broadcastable to a 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where `total_sequence_length = past_sequence_length + kv_sequence_length`. Two types of masks are supported: a boolean mask where a value of `True` indicates that the element should take part in attention, or a float mask of the same type as query, key and value that is added to the attention score.
- **past_key** (optional, heterogeneous) - **T1**: Past state cache for key with shape `(batch_size, kv_num_heads, past_sequence_length, head_size)`.
- **past_value** (optional, heterogeneous) - **T2**: Past state cache for value with shape `(batch_size, kv_num_heads, past_sequence_length, v_head_size)`.

### Outputs

Between 1 and 4 outputs.

- **Y** (heterogeneous) - **T1**: The output tensor. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, v_head_size)` or 3D tensor with shape `(batch_size, q_sequence_length, hidden_size)`. For cases with a 3D input tensor, `hidden_size = q_num_heads * v_head_size`.
- **present_key** (optional, heterogeneous) - **T1**: Updated key cache with shape `(batch_size, kv_num_heads, total_sequence_length, head_size)` where `total_sequence_length = past_sequence_length + kv_sequence_length`.
- **present_value** (optional, heterogeneous) - **T2**: Updated value cache with shape `(batch_size, kv_num_heads, total_sequence_length, v_head_size)` where `total_sequence_length = past_sequence_length + kv_sequence_length`.
- **qk_matmul_output** (optional, heterogeneous) - **T1**: The output of the QK matmul. 4D tensor with shape `(batch_size, q_num_heads, q_sequence_length, total_sequence_length)` where `total_sequence_length = past_sequence_length + kv_sequence_length`.

### Type Constraints

* **T1** in ( `tensor(bfloat16)`, `tensor(double)`, `tensor(float)`, `tensor(float16)` ): Constrain Q and K input types to float tensors.
* **T2** in ( `tensor(bfloat16)`, `tensor(double)`, `tensor(float)`, `tensor(float16)` ): Constrain V input types to float tensors.
* **U** in ( `tensor(bfloat16)`, `tensor(bool)`, `tensor(double)`, `tensor(float)`, `tensor(float16)`, `tensor(int16)`, `tensor(int32)`, `tensor(int64)`, `tensor(int8)`, `tensor(uint16)`, `tensor(uint32)`, `tensor(uint64)`, `tensor(uint8)` ): Constrain the `attn_mask` input to boolean tensors or tensors of the listed numeric types.
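As a usage sketch rather than a normative example, the snippet below builds a one-node opset-23 model using the 3D input layout with a GQA configuration (`q_num_heads=8`, `kv_num_heads=2`) and causal masking. The tensor names and sizes are arbitrary, and checking the model assumes an installed `onnx` release that already knows opset 23:

```python
import onnx
from onnx import TensorProto, helper

batch, q_seq, kv_seq, head_size = 2, 4, 6, 16   # illustrative sizes
q_heads, kv_heads = 8, 2                        # GQA: q_num_heads % kv_num_heads == 0

node = helper.make_node(
    "Attention",
    inputs=["Q", "K", "V"],
    outputs=["Y"],
    q_num_heads=q_heads,    # required because the inputs are 3D
    kv_num_heads=kv_heads,
    is_causal=1,            # causal masking instead of an explicit attn_mask
)

graph = helper.make_graph(
    [node],
    "attention_gqa_example",
    inputs=[
        helper.make_tensor_value_info("Q", TensorProto.FLOAT, [batch, q_seq, q_heads * head_size]),
        helper.make_tensor_value_info("K", TensorProto.FLOAT, [batch, kv_seq, kv_heads * head_size]),
        helper.make_tensor_value_info("V", TensorProto.FLOAT, [batch, kv_seq, kv_heads * head_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [batch, q_seq, q_heads * head_size]),
    ],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])
onnx.checker.check_model(model)
```

With the 4D layout the head counts are implied by the input shapes, so the `q_num_heads`/`kv_num_heads` attributes are only needed for 3D inputs, as noted in the attribute descriptions above.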