ai.onnx.preview - FlexAttention¶
FlexAttention - 1 (ai.onnx.preview)¶
Version¶
name: FlexAttention (GitHub)
domain: ai.onnx.preview
since_version: 1
function: True
support_level: SupportType.EXPERIMENTAL
shape inference: True
No versioning maintained for experimental ops.
Summary¶
Computes scaled dot-product attention over rank-4 (batched, multi-head) inputs, with optional user-provided customization subgraphs at two stages:
score_mod: Modify the attention score tensor after Q·K^T
prob_mod: Modify the probability tensor after Softmax
This operator mirrors the capabilities of PyTorch’s flex_attention: https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html
Input Shapes (MUST be rank-4 tensors):
Q: (batch_size, q_num_heads, q_sequence_length, head_size)
K: (batch_size, kv_num_heads, kv_sequence_length, head_size)
V: (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
Output Shape:
Y: (batch_size, q_num_heads, q_sequence_length, v_head_size)
FlexAttention Computation:
Scores = (Q @ K^T) * scale
Scores = score_mod(Scores) # if 'score_mod' is provided
Probs = Softmax(Scores, axis=-1)
Probs = prob_mod(Probs) # if 'prob_mod' is provided
Y = Probs @ V
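The computation above can be sketched in NumPy (an illustrative reference, not the normative definition); `score_mod` and `prob_mod` are modeled here as optional Python callables rather than ONNX subgraphs:

```python
import numpy as np

def flex_attention(Q, K, V, scale=None, score_mod=None, prob_mod=None):
    # Q: (B, Hq, L, E), K: (B, Hkv, S, E), V: (B, Hkv, S, Ev).
    # This sketch assumes q_num_heads == kv_num_heads; see the GQA note below.
    if scale is None:
        scale = 1.0 / np.sqrt(Q.shape[-1])          # default: 1/sqrt(head_size)
    scores = (Q @ K.transpose(0, 1, 3, 2)) * scale  # (B, Hq, L, S)
    if score_mod is not None:
        scores = score_mod(scores)                  # e.g. masking via -inf
    # numerically stable softmax over the last axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    if prob_mod is not None:
        probs = prob_mod(probs)
    return probs @ V                                # (B, Hq, L, Ev)
```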
Grouped Query Attention (GQA):
When q_num_heads != kv_num_heads, each K/V head is shared by a contiguous
group of query heads in head-index order. Let
group_size = q_num_heads / kv_num_heads; then query head h uses K/V head
floor(h / group_size). q_num_heads must be a multiple of
kv_num_heads.
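One way to realize the mapping above (a sketch, not prescribed by the spec): repeating each K/V head `group_size` times along the head axis reproduces the `floor(h / group_size)` assignment, after which ordinary multi-head attention applies:

```python
import numpy as np

def expand_kv_heads(KV, q_num_heads):
    # KV: (B, Hkv, S, E). Repeat each K/V head group_size times so that
    # query head h lines up with K/V head floor(h / group_size).
    kv_num_heads = KV.shape[1]
    assert q_num_heads % kv_num_heads == 0, \
        "q_num_heads must be a multiple of kv_num_heads"
    group_size = q_num_heads // kv_num_heads
    return np.repeat(KV, group_size, axis=1)  # (B, Hq, S, E)
```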
Modifier Subgraphs (score_mod, prob_mod): Each modifier subgraph takes exactly one rank-4 tensor input and must produce exactly one rank-4 tensor output of the same shape and element type.
score_mod input/output shape: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length)
prob_mod input/output shape: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length)
The element type is determined by softmax_precision (defaults to float32 for non-double inputs, otherwise double).
Masking can be expressed in score_mod by writing masked positions as -inf (or a large negative value appropriate for the target precision).
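For example, a causal mask can be expressed as a score_mod that writes -inf above the diagonal. A NumPy sketch of the modifier's logic (in the actual attribute this would be an ONNX subgraph):

```python
import numpy as np

def causal_score_mod(scores):
    # scores: (B, Hq, L, S). Disallow each query position i from attending
    # to key positions j > i by writing -inf there; softmax maps -inf to 0.
    L, S = scores.shape[-2], scores.shape[-1]
    mask = np.triu(np.ones((L, S), dtype=bool), k=1)  # True where j > i
    return np.where(mask, -np.inf, scores)
```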
Attributes¶
prob_mod - GRAPH :
Optional probability modifier subgraph with 1 rank-4 tensor input and 1 rank-4 tensor output of the same shape and element type: (probs) -> probs_out. probs has softmax_precision element type and shape (B, Hq, L, S). The output must preserve the input shape.
scale - FLOAT :
Scaling factor for Q*K^T. Defaults to 1/sqrt(head_size).
score_mod - GRAPH :
Optional score modifier subgraph with 1 rank-4 tensor input and 1 rank-4 tensor output of the same shape and element type: (scores) -> scores_out. scores has softmax_precision element type and shape (B, Hq, L, S). The output must preserve the input shape.
softmax_precision - INT :
Floating-point precision for softmax computation. Defaults to float32 for non-double inputs, otherwise uses double. Must be explicitly specified for non-float types.
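To illustrate the precision rule (a sketch of one plausible reading, assuming the default behavior described above: e.g. float16 scores upcast to float32 for softmax, then cast back):

```python
import numpy as np

def softmax_with_precision(scores, softmax_dtype=np.float32):
    # Compute softmax in the (typically wider) softmax_precision dtype,
    # then cast back to the input element type.
    out_dtype = scores.dtype
    s = scores.astype(softmax_dtype)
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(s)
    p = p / p.sum(axis=-1, keepdims=True)
    return p.astype(out_dtype)
```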
Inputs¶
Q (heterogeneous) - T1:
Query tensor with shape (batch_size, q_num_heads, q_seq_len, head_size).
K (heterogeneous) - T1:
Key tensor with shape (batch_size, kv_num_heads, kv_seq_len, head_size).
V (heterogeneous) - T1:
Value tensor with shape (batch_size, kv_num_heads, kv_seq_len, v_head_size).
Outputs¶
Y (heterogeneous) - T1:
Output tensor with shape (batch_size, q_num_heads, q_seq_len, v_head_size).
Type Constraints¶
T1 in (tensor(bfloat16), tensor(double), tensor(float), tensor(float16)):
Constrain Q, K, V to float tensors.