(l-onnx-docai-onnx-preview-FlexAttention)=

# ai.onnx.preview - FlexAttention

(l-onnx-opai-onnx-preview-flexattention-1)=

## FlexAttention - 1 (ai.onnx.preview)

### Version

- **name**: [FlexAttention (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ai.onnx.preview.FlexAttention)
- **domain**: `ai.onnx.preview`
- **since_version**: `1`
- **function**: `True`
- **support_level**: `SupportType.EXPERIMENTAL`
- **shape inference**: `True`

No versioning is maintained for experimental ops.

### Summary

Computes scaled dot-product attention over rank-4 (batched, multi-head) inputs, with optional user-provided customization subgraphs at two stages:

1. `score_mod`: modifies the attention score tensor after Q·K^T
2. `prob_mod`: modifies the probability tensor after Softmax

This operator mirrors the capabilities of PyTorch's `flex_attention`:
https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html

Input shapes (MUST be rank-4 tensors):

- Q: `(batch_size, q_num_heads, q_sequence_length, head_size)`
- K: `(batch_size, kv_num_heads, kv_sequence_length, head_size)`
- V: `(batch_size, kv_num_heads, kv_sequence_length, v_head_size)`

Output shape:

- Y: `(batch_size, q_num_heads, q_sequence_length, v_head_size)`

FlexAttention computation:

```
Scores = (Q @ K^T) * scale
Scores = score_mod(Scores)   # if 'score_mod' is provided
Probs  = Softmax(Scores, axis=-1)
Probs  = prob_mod(Probs)     # if 'prob_mod' is provided
Y      = Probs @ V
```

Grouped Query Attention (GQA): when `q_num_heads != kv_num_heads`, each K/V head is shared by a contiguous group of query heads in head-index order. Let `group_size = q_num_heads / kv_num_heads`; then query head `h` uses K/V head `floor(h / group_size)`. `q_num_heads` must be a multiple of `kv_num_heads`.

Modifier subgraphs (`score_mod`, `prob_mod`): each modifier subgraph takes exactly one rank-4 tensor input and must produce exactly one rank-4 tensor output of the same shape and element type.
- `score_mod` input/output shape: `(batch_size, q_num_heads, q_sequence_length, kv_sequence_length)`
- `prob_mod` input/output shape: `(batch_size, q_num_heads, q_sequence_length, kv_sequence_length)`

The element type is determined by `softmax_precision` (defaults to float32 for non-double inputs, otherwise double). Masking can be expressed in `score_mod` by writing masked positions as `-inf` (or a large negative value appropriate for the target precision).

### Attributes

* **prob_mod - GRAPH** : Optional probability modifier subgraph with 1 rank-4 tensor input and 1 rank-4 tensor output of the same shape and element type: `(probs) -> probs_out`. `probs` has `softmax_precision` element type and shape `(B, Hq, L, S)`. The output must preserve the input shape.
* **scale - FLOAT** : Scaling factor for `Q*K^T`. Defaults to `1/sqrt(head_size)`.
* **score_mod - GRAPH** : Optional score modifier subgraph with 1 rank-4 tensor input and 1 rank-4 tensor output of the same shape and element type: `(scores) -> scores_out`. `scores` has `softmax_precision` element type and shape `(B, Hq, L, S)`. The output must preserve the input shape.
* **softmax_precision - INT** : Floating-point precision for the softmax computation. Defaults to float32 for non-double inputs, otherwise double. Must be explicitly specified for non-float types.

### Inputs

- **Q** (heterogeneous) - **T1**: Query tensor with shape `(batch_size, q_num_heads, q_seq_len, head_size)`.
- **K** (heterogeneous) - **T1**: Key tensor with shape `(batch_size, kv_num_heads, kv_seq_len, head_size)`.
- **V** (heterogeneous) - **T1**: Value tensor with shape `(batch_size, kv_num_heads, kv_seq_len, v_head_size)`.

### Outputs

- **Y** (heterogeneous) - **T1**: Output tensor with shape `(batch_size, q_num_heads, q_seq_len, v_head_size)`.

### Type Constraints

* **T1** in ( `tensor(bfloat16)`, `tensor(double)`, `tensor(float)`, `tensor(float16)` ): Constrain Q, K, V to float tensors.
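The semantics above can be illustrated with a non-normative NumPy sketch. The function and helper names (`flex_attention`, `causal_score_mod`) are illustrative, not part of any API; the sketch emulates the GQA head mapping and shows causal masking expressed as a `score_mod` that writes `-inf` at masked positions before the softmax. Precision handling via `softmax_precision` is omitted: NumPy's default float64 is used throughout.

```python
import numpy as np

def flex_attention(Q, K, V, scale=None, score_mod=None, prob_mod=None):
    """Q: (B, Hq, L, D), K: (B, Hkv, S, D), V: (B, Hkv, S, Dv) -> (B, Hq, L, Dv)."""
    B, Hq, L, D = Q.shape
    Hkv = K.shape[1]
    assert Hq % Hkv == 0, "q_num_heads must be a multiple of kv_num_heads"
    if scale is None:
        scale = 1.0 / np.sqrt(D)                       # default: 1/sqrt(head_size)
    # GQA: query head h uses K/V head floor(h / group_size)
    group_size = Hq // Hkv
    idx = np.arange(Hq) // group_size
    Kg, Vg = K[:, idx], V[:, idx]                      # (B, Hq, S, D), (B, Hq, S, Dv)
    scores = (Q @ Kg.transpose(0, 1, 3, 2)) * scale    # (B, Hq, L, S)
    if score_mod is not None:
        scores = score_mod(scores)
    # Numerically stable softmax over the last axis; exp(-inf) contributes 0.
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    if prob_mod is not None:
        probs = prob_mod(probs)
    return probs @ Vg

def causal_score_mod(scores):
    """Example score_mod: mask future positions by writing -inf pre-softmax."""
    L, S = scores.shape[-2:]
    mask = np.triu(np.ones((L, S), dtype=bool), k=1)   # True strictly above diagonal
    return np.where(mask, -np.inf, scores)
```

With the causal modifier, the first query position can only attend to key position 0, so its output row reproduces the corresponding value row; with `Hq=4, Hkv=2`, query heads 0-1 read K/V head 0 and query heads 2-3 read K/V head 1.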