(l-onnx-doc-LinearAttention)=

# LinearAttention

(l-onnx-op-linearattention-27)=

## LinearAttention - 27

### Version

- **name**: [LinearAttention (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#LinearAttention)
- **domain**: `main`
- **since_version**: `27`
- **function**: `True`
- **support_level**: `SupportType.COMMON`
- **shape inference**: `True`

This version of the operator has been available **since version 27**.

### Summary

Unified linear attention operator for both autoregressive decoding (T=1) and prefill (T>1).

The query, key, value, and (where applicable) decay and beta inputs use a 3D packed layout (B, T, H * D), in which the heads are flattened into the last dimension. The q_num_heads and kv_num_heads attributes are always required. They are used to unpack the inputs to 4D internally for computation. The optional past_state input and the present_state output are 4D with shape (B, H_kv, d_k, d_v).

Grouped-query attention (GQA) is supported: q_num_heads must be a positive multiple of kv_num_heads. When q_num_heads == kv_num_heads, this reduces to multi-head linear attention. When q_num_heads > kv_num_heads, each KV head (and its recurrent state) is shared by `q_num_heads / kv_num_heads` query heads. Multi-query attention is the special case kv_num_heads == 1.

The update_rule attribute selects the recurrence type:

- "linear": `S_t = S_{t-1} + k_t ⊗ v_t`; `o_t = scale * q_t^T S_t`
- "gated": `S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t`; `o_t = scale * q_t^T S_t`
- "delta": `S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t)`; `o_t = scale * q_t^T S_t`
- "gated_delta": `S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t)`; `o_t = scale * q_t^T S_t`

Here g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes the outer product.

Semantics: the result is equivalent to running the recurrent update sequentially, one token at a time. Implementations may instead use chunk-parallel algorithms for GPU efficiency.
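The four recurrences can be sketched as a sequential NumPy loop. This is an illustrative sketch, not the ONNX reference implementation, and the name `linear_attention_reference` is hypothetical. It makes three assumptions that the text does not state. First, query head `h` reads KV head `h // (q_num_heads / kv_num_heads)`, which is the usual GQA grouping. Second, the state is accumulated in float32. Third, with a per-key-dimension decay, `exp(g_t)` scales the rows (the d_k axis) of the state, and in "gated_delta" the correction term uses that decayed state.

```python
import numpy as np


def linear_attention_reference(query, key, value, q_num_heads, kv_num_heads,
                               past_state=None, decay=None, beta=None,
                               update_rule="gated_delta", scale=0.0):
    """Hypothetical token-by-token sketch of the LinearAttention recurrences."""
    if update_rule not in ("linear", "gated", "delta", "gated_delta"):
        raise ValueError(f"unknown update_rule: {update_rule}")
    if q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a multiple of kv_num_heads")
    if update_rule in ("gated", "gated_delta") and decay is None:
        raise ValueError(f"{update_rule} requires decay")
    if update_rule in ("delta", "gated_delta") and beta is None:
        raise ValueError(f"{update_rule} requires beta")

    B, T, _ = query.shape
    d_k = query.shape[-1] // q_num_heads
    d_v = value.shape[-1] // kv_num_heads
    group = q_num_heads // kv_num_heads  # query heads per KV head
    if scale == 0.0:
        scale = 1.0 / np.sqrt(d_k)

    # Unpack the 3D packed inputs into (B, T, heads, dim).
    q = query.reshape(B, T, q_num_heads, d_k).astype(np.float32)
    k = key.reshape(B, T, kv_num_heads, d_k).astype(np.float32)
    v = value.reshape(B, T, kv_num_heads, d_v).astype(np.float32)
    # decay: last dim becomes d_k (per-key-dimension) or 1 (per-head scalar).
    g = None if decay is None else decay.reshape(B, T, kv_num_heads, -1).astype(np.float32)
    # beta: last dim is H_kv or 1; both broadcast over the KV heads.
    b = None if beta is None else beta.reshape(B, T, -1).astype(np.float32)

    # Assumption: keep the (B, H_kv, d_k, d_v) state in float32.
    if past_state is None:
        S = np.zeros((B, kv_num_heads, d_k, d_v), dtype=np.float32)
    else:
        S = past_state.astype(np.float32)

    out = np.empty((B, T, q_num_heads, d_v), dtype=np.float32)
    for t in range(T):
        kt, vt = k[:, t], v[:, t]  # (B, H_kv, d_k), (B, H_kv, d_v)
        if update_rule in ("gated", "gated_delta"):
            # exp(g_t) scales the rows (d_k axis) of the previous state.
            S_prev = np.exp(g[:, t])[..., None] * S
        else:
            S_prev = S
        if update_rule in ("delta", "gated_delta"):
            pred = np.einsum("bhk,bhkv->bhv", kt, S_prev)  # S_prev^T k_t
            bt = b[:, t][:, :, None, None]
            S = S_prev + bt * np.einsum("bhk,bhv->bhkv", kt, vt - pred)
        else:
            S = S_prev + np.einsum("bhk,bhv->bhkv", kt, vt)  # k_t ⊗ v_t
        # Assumed GQA grouping: query head h reads KV head h // group.
        S_q = np.repeat(S, group, axis=1)  # (B, H_q, d_k, d_v)
        out[:, t] = scale * np.einsum("bhk,bhkv->bhv", q[:, t], S_q)

    return out.reshape(B, T, q_num_heads * d_v), S


rng = np.random.default_rng(0)
B, T, H_q, H_kv, d_k, d_v = 1, 4, 4, 2, 8, 8
q = rng.standard_normal((B, T, H_q * d_k)).astype(np.float32)
k = rng.standard_normal((B, T, H_kv, d_k)).astype(np.float32)
k = (k / np.linalg.norm(k, axis=-1, keepdims=True)).reshape(B, T, H_kv * d_k)  # L2-normalize per head
v = rng.standard_normal((B, T, H_kv * d_v)).astype(np.float32)
g = np.log(rng.uniform(0.9, 1.0, size=(B, T, H_kv))).astype(np.float32)  # per-head log-space decay
beta = rng.uniform(0.1, 0.9, size=(B, T, H_kv)).astype(np.float32)  # stands in for a sigmoid output

# Prefill: all T tokens in one call.
out_prefill, state_prefill = linear_attention_reference(q, k, v, H_q, H_kv, decay=g, beta=beta)

# Decode: one token per call, feeding present_state back in as past_state.
state, outs = None, []
for t in range(T):
    o, state = linear_attention_reference(q[:, t:t + 1], k[:, t:t + 1], v[:, t:t + 1], H_q, H_kv,
                                          past_state=state, decay=g[:, t:t + 1], beta=beta[:, t:t + 1])
    outs.append(o)
print(np.allclose(np.concatenate(outs, axis=1), out_prefill), np.allclose(state, state_prefill))  # True True
```

Both paths in the demo perform the same sequential updates, so prefill and chained T=1 decoding agree. A production kernel would replace the Python loop with a chunk-parallel formulation.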
### Attributes

* **chunk_size - INT** (default is `64`): Chunk size for the chunk-parallel WY decomposition used during prefill (T>1). This is a tuning hint only. It does not affect output correctness.
* **kv_num_heads - INT** (required): Number of key/value heads. Always required.
* **q_num_heads - INT** (required): Number of query heads. Always required.
* **scale - FLOAT** (default is `0.0`): Output scaling factor. When 0.0 (the default), the operator derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set it explicitly to override.
* **update_rule - STRING** (default is `gated_delta`): The update rule for the linear attention recurrence. One of 'linear', 'gated', 'delta', or 'gated_delta'.

### Inputs

Between 3 and 6 inputs.

- **query** (heterogeneous) - **T**: Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.
- **key** (heterogeneous) - **T**: Key vectors with 3D packed shape (B, T, H_kv * d_k). Keys should be L2-normalized for the 'delta' and 'gated_delta' modes.
- **value** (heterogeneous) - **T**: Value vectors with 3D packed shape (B, T, H_kv * d_v).
- **past_state** (optional, heterogeneous) - **S**: Recurrent state from the previous step with shape (B, H_kv, d_k, d_v). Always 4D. Defaults to zeros if not provided.
- **decay** (optional, heterogeneous) - **T**: Exponential decay gate in log-space. 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for the 'gated' and 'gated_delta' modes.
- **beta** (optional, heterogeneous) - **T**: Update rate (a sigmoid output). 3D packed shape: (B, T, H_kv) or (B, T, 1). Required for the 'delta' and 'gated_delta' modes.

### Outputs

- **output** (heterogeneous) - **T**: Attention output with 3D packed shape (B, T, H_q * d_v).
- **present_state** (heterogeneous) - **S**: Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.
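The shape rules in the attribute and input lists can be collected into a small helper that derives d_k, d_v, the default scale, and both output shapes. This is a hypothetical sketch. The name `infer_linear_attention_shapes` is invented here and is not part of the ONNX Python API. The helper checks only the constraints stated on this page.

```python
import math


def infer_linear_attention_shapes(query_shape, key_shape, value_shape,
                                  q_num_heads, kv_num_heads,
                                  decay_shape=None, beta_shape=None):
    """Hypothetical helper: check packed input shapes and derive output shapes."""
    if q_num_heads <= 0 or kv_num_heads <= 0 or q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a positive multiple of kv_num_heads")
    B, T, q_width = query_shape
    if q_width % q_num_heads != 0:
        raise ValueError("query last dim must be divisible by q_num_heads")
    d_k = q_width // q_num_heads

    if tuple(key_shape) != (B, T, kv_num_heads * d_k):
        raise ValueError(f"key must have shape {(B, T, kv_num_heads * d_k)}")
    v_b, v_t, v_width = value_shape
    if (v_b, v_t) != (B, T) or v_width % kv_num_heads != 0:
        raise ValueError("value must have shape (B, T, H_kv * d_v)")
    d_v = v_width // kv_num_heads

    # decay: per-key-dimension (H_kv * d_k) or per-head scalar (H_kv).
    if decay_shape is not None and tuple(decay_shape) not in {
            (B, T, kv_num_heads * d_k), (B, T, kv_num_heads)}:
        raise ValueError("decay must be (B, T, H_kv * d_k) or (B, T, H_kv)")
    # beta: per-head (H_kv) or shared across heads (1).
    if beta_shape is not None and tuple(beta_shape) not in {
            (B, T, kv_num_heads), (B, T, 1)}:
        raise ValueError("beta must be (B, T, H_kv) or (B, T, 1)")

    return {
        "d_k": d_k,
        "d_v": d_v,
        "default_scale": 1.0 / math.sqrt(d_k),  # used when scale == 0.0
        "output": (B, T, q_num_heads * d_v),
        "present_state": (B, kv_num_heads, d_k, d_v),
    }


shapes = infer_linear_attention_shapes(
    query_shape=(2, 16, 512), key_shape=(2, 16, 256), value_shape=(2, 16, 512),
    q_num_heads=8, kv_num_heads=4, decay_shape=(2, 16, 4), beta_shape=(2, 16, 1))
print(shapes["output"], shapes["present_state"], shapes["default_scale"])
# (2, 16, 1024) (2, 4, 64, 128) 0.125
```

The example uses 8 query heads sharing 4 KV heads, so each recurrent state serves two query heads. This is why present_state is sized by H_kv while output is sized by H_q.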
### Type Constraints

* **T** in ( `tensor(bfloat16)`, `tensor(float)`, `tensor(float16)` ): Constrain activation input and output types to float16, bfloat16, or float32 tensors.
* **S** in ( `tensor(bfloat16)`, `tensor(float)`, `tensor(float16)` ): Constrain state types to float16, bfloat16, or float32 tensors. The state should be either float32 or the same type as T. Float32 is preferable for numerical stability on long sequences.
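Here is a toy illustration of why the state type matters on long sequences. It assumes the 'linear' rule with d_k = d_v = 1 and k_t = v_t = 1, so each step adds exactly 1.0 to the state. Real states do not grow this simply. The point is only that a float16 accumulator stops changing once the increment falls below its resolution: float16 values near 2048 are spaced 2 apart, so 2048 + 1 rounds back to 2048. The function name is hypothetical.

```python
import numpy as np


def accumulate_linear_state(num_tokens, state_dtype):
    """Hypothetical toy 'linear' recurrence with d_k = d_v = 1 and k_t = v_t = 1."""
    state = np.zeros((1, 1), dtype=state_dtype)
    k_outer_v = np.ones((1, 1), dtype=state_dtype)  # k_t ⊗ v_t for k_t = v_t = [1.0]
    for _ in range(num_tokens):
        state = state + k_outer_v  # S_t = S_{t-1} + k_t ⊗ v_t, stored in state_dtype
    return float(state[0, 0])


print(accumulate_linear_state(4096, np.float16))  # 2048.0 (float16 cannot represent 2049)
print(accumulate_linear_state(4096, np.float32))  # 4096.0
```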