LinearAttention
LinearAttention - 27
Version
name: LinearAttention
domain: main
since_version: 27
function: True
support_level: SupportType.COMMON
shape inference: True
This version of the operator has been available since version 27.
Summary
A single linear attention operator that covers both autoregressive decoding (T = 1) and prefill (T > 1).
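Because the operator is defined by a token-by-token recurrence, prefill and decoding compose: running a prefix and then feeding its final state back as past_state for later tokens should reproduce a single pass over the whole sequence. The sketch below checks this for one (batch, head) pair using the 'gated' rule with a per-head scalar decay. The function name gated_linear_attention is hypothetical, and the decay values are chosen non-positive only so that exp(g) ≤ 1 in the demo; the spec does not constrain them.

```python
import numpy as np


def gated_linear_attention(q, k, v, g, past_state=None, scale=1.0):
    """Reference 'gated' rule for one (batch, head): q, k (T, d_k); v (T, d_v); g (T,) log-decay."""
    d_k, d_v = q.shape[1], v.shape[1]
    S = np.zeros((d_k, d_v)) if past_state is None else past_state.copy()
    out = np.empty((q.shape[0], d_v))
    for t in range(q.shape[0]):
        S = np.exp(g[t]) * S + np.outer(k[t], v[t])  # S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t
        out[t] = scale * (q[t] @ S)                  # o_t = scale * q_t^T S_t
    return out, S


rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 3
q, k = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
v = rng.standard_normal((T, d_v))
g = -rng.random(T)  # non-positive log-decay, so each exp(g_t) lies in (0, 1]

# One prefill pass over all T tokens.
full_out, full_state = gated_linear_attention(q, k, v, g)

# Prefill the first 5 tokens, then decode the rest one token (T=1) at a time.
out, state = gated_linear_attention(q[:5], k[:5], v[:5], g[:5])
outs = [out]
for t in range(5, T):
    o, state = gated_linear_attention(q[t:t + 1], k[t:t + 1], v[t:t + 1], g[t:t + 1], past_state=state)
    outs.append(o)
```

The concatenated decode outputs match full_out, and the last state matches full_state, which is what lets a runtime switch from prefill to T = 1 decoding without recomputing the prefix.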
The query, key, and value inputs, and the decay and beta inputs where applicable, use a 3D packed layout (B, T, H * D) in which the heads are flattened into the last dimension. The q_num_heads and kv_num_heads attributes are always required; the operator uses them to unpack these inputs to 4D internally. The optional past_state input and the present_state output are always 4D, with shape (B, H_kv, d_k, d_v).
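A minimal NumPy sketch of the packing convention follows. It assumes each head occupies a contiguous block of D values in the last dimension (head-major order); the text above only says that heads are flattened into the last dimension, and the helper names unpack_heads and pack_heads are hypothetical.

```python
import numpy as np


def unpack_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Unpack (B, T, H * D) into (B, H, T, D), assuming heads are contiguous blocks of D."""
    b, t, packed = x.shape
    if packed % num_heads != 0:
        raise ValueError(f"last dimension {packed} is not divisible by {num_heads} heads")
    d = packed // num_heads
    return x.reshape(b, t, num_heads, d).transpose(0, 2, 1, 3)


def pack_heads(x: np.ndarray) -> np.ndarray:
    """Inverse of unpack_heads: (B, H, T, D) back to (B, T, H * D)."""
    b, h, t, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, t, h * d)


# Example: B=2, T=5, q_num_heads=4, d_k=8, so query has shape (2, 5, 32).
query = np.random.default_rng(0).standard_normal((2, 5, 4 * 8)).astype(np.float32)
q4 = unpack_heads(query, num_heads=4)
print(q4.shape)  # (2, 4, 5, 8)
```

Under this assumption, head h of the query is simply the slice query[..., h*d_k:(h+1)*d_k], and packing is an exact round trip.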
Group-query attention (GQA) is supported: q_num_heads must be a positive multiple of
kv_num_heads. When q_num_heads == kv_num_heads this reduces to multi-headed linear
attention; when q_num_heads > kv_num_heads each KV head (and its recurrent state) is
shared by q_num_heads / kv_num_heads query heads (multi-query attention is the
special case kv_num_heads == 1).
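The head-sharing rule can be sketched as an index mapping. This assumes contiguous grouping (query heads 0 to g-1 use KV head 0, the next g use KV head 1, and so on, with g = q_num_heads / kv_num_heads), which is a common GQA convention but is not stated in the text above. Both function names are hypothetical.

```python
import numpy as np


def kv_head_for_query_head(q_head: int, q_num_heads: int, kv_num_heads: int) -> int:
    """Index of the KV head (and recurrent state) read by a query head.

    Assumes contiguous grouping: query heads [0, g) share KV head 0, [g, 2g) share KV head 1, ...
    """
    if q_num_heads <= 0 or kv_num_heads <= 0 or q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a positive multiple of kv_num_heads")
    group_size = q_num_heads // kv_num_heads
    return q_head // group_size


def expand_kv_heads(x: np.ndarray, q_num_heads: int) -> np.ndarray:
    """Repeat axis 1 (KV heads) once per query head in the group, e.g. for a state (B, H_kv, d_k, d_v)."""
    kv_num_heads = x.shape[1]
    return np.repeat(x, q_num_heads // kv_num_heads, axis=1)


print([kv_head_for_query_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With q_num_heads == kv_num_heads the mapping is the identity (multi-headed), and with kv_num_heads == 1 every query head maps to head 0 (multi-query).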
The update_rule attribute selects the recurrence type:
“linear”: S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
“gated”: S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
“delta”: S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t
“gated_delta”: S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t
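where g_t is the decay (in log-space, so the state is multiplied by exp(g_t)), β_t is the update rate, and k_t ⊗ v_t denotes the outer product k_t v_t^T, a d_k × d_v matrix with the same shape as S_t. When decay is given per key dimension, exp(g_t) scales each row of S_{t-1} by its own factor; when it is a per-head scalar, it scales the whole state.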
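The four recurrences can be written as one step function for a single (batch, head) pair. This is a minimal NumPy sketch, not the operator's implementation: the name recurrent_step is hypothetical, and for per-key-dimension decay it assumes exp(g_t) scales the rows of S, with the delta correction computed against that decayed state (which matches the formula above exactly in the scalar case).

```python
import numpy as np


def recurrent_step(S, q, k, v, update_rule="gated_delta", g=None, beta=None, scale=1.0):
    """Apply one token of the recurrence for a single (batch, head) pair.

    S: state of shape (d_k, d_v); q, k: shape (d_k,); v: shape (d_v,).
    g: log-space decay, a scalar (per head) or shape (d_k,) (per key dimension).
    beta: scalar update rate. Returns (o_t with shape (d_v,), S_t).
    """
    if update_rule in ("gated", "gated_delta"):
        # exp(g) reshaped to (1, 1) or (d_k, 1): one global factor or one factor per row of S.
        decay = np.exp(np.asarray(g, dtype=S.dtype)).reshape(-1, 1)
        S_prev = decay * S
    elif update_rule in ("linear", "delta"):
        S_prev = S
    else:
        raise ValueError(f"unknown update_rule: {update_rule!r}")

    if update_rule in ("linear", "gated"):
        S_new = S_prev + np.outer(k, v)
    else:
        # Delta rule: write only the part of v that the (decayed) state does not already predict for k.
        S_new = S_prev + beta * np.outer(k, v - S_prev.T @ k)
    return scale * (q @ S_new), S_new


# With a unit-norm key and beta = 1, the delta rule overwrites the value stored under k,
# whereas the plain linear rule accumulates both values.
k = np.array([0.6, 0.8, 0.0])
v1, v2 = np.array([1.0, 2.0]), np.array([-3.0, 4.0])
S_delta, S_lin = np.zeros((3, 2)), np.zeros((3, 2))
for v in (v1, v2):
    o_delta, S_delta = recurrent_step(S_delta, k, k, v, "delta", beta=1.0)
    o_lin, S_lin = recurrent_step(S_lin, k, k, v, "linear")
# o_delta is close to v2, while o_lin is close to v1 + v2.
```

This contrast is also why keys should be L2-normalized for the delta modes: the overwrite only cancels the old value exactly when k_t^T k_t = 1.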
Semantics: the output is defined as the result of applying the recurrent update sequentially, one token at a time. Implementations may instead use chunk-parallel algorithms for GPU efficiency, provided they produce the same result.
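To illustrate why chunking need not change the result, here is a sketch for the plain 'linear' rule only, for one (batch, head) pair. It is a simpler technique than the WY decomposition referred to by the chunk_size attribute (which is needed for the delta rules and is not shown): within a chunk it uses a causal masked product, and across chunks it carries the state. Function names are hypothetical.

```python
import numpy as np


def linear_attention_sequential(q, k, v, S0, scale):
    """Token-by-token 'linear' rule. q, k: (T, d_k); v: (T, d_v); S0: (d_k, d_v)."""
    S = S0.copy()
    out = np.empty((q.shape[0], v.shape[1]), dtype=S.dtype)
    for t in range(q.shape[0]):
        S = S + np.outer(k[t], v[t])
        out[t] = scale * (q[t] @ S)
    return out, S


def linear_attention_chunked(q, k, v, S0, scale, chunk_size):
    """Same result, processing chunk_size tokens at a time with matrix products."""
    S = S0.copy()
    T = q.shape[0]
    out = np.empty((T, v.shape[1]), dtype=S.dtype)
    for start in range(0, T, chunk_size):
        stop = min(start + chunk_size, T)
        qc, kc, vc = q[start:stop], k[start:stop], v[start:stop]
        n = stop - start
        inter = qc @ S  # contribution of all tokens before this chunk, carried in S
        causal = np.tril(np.ones((n, n), dtype=S.dtype))  # token t sees chunk tokens s <= t
        intra = ((qc @ kc.T) * causal) @ vc
        out[start:stop] = scale * (inter + intra)
        S = S + kc.T @ vc  # sum of k_s ⊗ v_s over the chunk
    return out, S


rng = np.random.default_rng(0)
q, k = rng.standard_normal((10, 4)), rng.standard_normal((10, 4))
v = rng.standard_normal((10, 3))
S0 = np.zeros((4, 3))
ref, ref_S = linear_attention_sequential(q, k, v, S0, scale=0.5)
for chunk_size in (1, 3, 4, 64):
    out, S = linear_attention_chunked(q, k, v, S0, scale=0.5, chunk_size=chunk_size)
    assert np.allclose(out, ref) and np.allclose(S, ref_S)
```

Every chunk size gives the same outputs up to floating-point rounding, which is the sense in which chunk_size is only a tuning hint.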
Attributes
chunk_size - INT (default is 64): Chunk size for the chunk-parallel WY decomposition used during prefill (T > 1). This is a performance tuning hint only; it does not affect output correctness.
kv_num_heads - INT (required): Number of key/value heads.
q_num_heads - INT (required): Number of query heads.
scale - FLOAT (default is 0.0): Output scaling factor. When left at 0.0 (the default), the operator derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set it explicitly to override.
update_rule - STRING (default is gated_delta): The update rule for the linear attention recurrence. One of ‘linear’, ‘gated’, ‘delta’, or ‘gated_delta’.
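A minimal sketch of how the attributes combine, in particular how a scale of 0.0 acts as a sentinel for the derived 1/sqrt(d_k). The helper resolve_attributes and its checks are hypothetical; the divisibility checks follow from the packed layout and the GQA requirement described above.

```python
import math

UPDATE_RULES = ("linear", "gated", "delta", "gated_delta")


def resolve_attributes(query_last_dim: int, q_num_heads: int, kv_num_heads: int,
                       scale: float = 0.0, update_rule: str = "gated_delta"):
    """Return (d_k, effective_scale) after basic attribute checks."""
    if update_rule not in UPDATE_RULES:
        raise ValueError(f"update_rule must be one of {UPDATE_RULES}, got {update_rule!r}")
    if q_num_heads <= 0 or kv_num_heads <= 0 or q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a positive multiple of kv_num_heads")
    if query_last_dim % q_num_heads != 0:
        raise ValueError("query.shape[-1] must be divisible by q_num_heads")
    d_k = query_last_dim // q_num_heads
    # scale == 0.0 means "derive it"; any other value is used as given.
    effective_scale = scale if scale != 0.0 else 1.0 / math.sqrt(d_k)
    return d_k, effective_scale


print(resolve_attributes(4 * 64, q_num_heads=4, kv_num_heads=2))             # (64, 0.125)
print(resolve_attributes(4 * 64, q_num_heads=4, kv_num_heads=2, scale=0.3))  # (64, 0.3)
```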
Inputs
Between 3 and 6 inputs.
query (heterogeneous) - T:
Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.
key (heterogeneous) - T:
Key vectors with 3D packed shape (B, T, H_kv * d_k). Should be L2-normalized for delta/gated_delta modes.
value (heterogeneous) - T:
Value vectors with 3D packed shape (B, T, H_kv * d_v).
past_state (optional, heterogeneous) - S:
Recurrent state from previous step with shape (B, H_kv, d_k, d_v). Always 4D. If not provided, defaults to zeros.
decay (optional, heterogeneous) - T:
Decay gate in log-space (the recurrence multiplies the state by exp(decay)). 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for ‘gated’ and ‘gated_delta’ modes.
beta (optional, heterogeneous) - T:
Update rate, typically the output of a sigmoid so that it lies in (0, 1). 3D packed shape: (B, T, H_kv) for a per-head rate, or (B, T, 1) for a rate shared across heads. Required for ‘delta’ and ‘gated_delta’ modes.
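The mode-dependent requirements on decay and beta, and the two possible decay layouts, can be checked before calling the operator. The sketch below also shows key L2 normalization as a caller-side step, assuming normalization is over d_k within each KV head; the text above only says keys "should be" normalized and does not say whether the operator does this itself. Both helper names and the eps value are hypothetical.

```python
import numpy as np

NEEDS_DECAY = {"gated", "gated_delta"}
NEEDS_BETA = {"delta", "gated_delta"}


def classify_optional_inputs(update_rule, key, kv_num_heads, decay=None, beta=None):
    """Check decay/beta against the mode and return the decay layout, or None if unused."""
    b, t, packed = key.shape
    d_k = packed // kv_num_heads
    if update_rule in NEEDS_BETA:
        if beta is None:
            raise ValueError(f"beta is required for update_rule={update_rule!r}")
        if beta.shape not in ((b, t, kv_num_heads), (b, t, 1)):
            raise ValueError(f"unexpected beta shape {beta.shape}")
    if update_rule not in NEEDS_DECAY:
        return None
    if decay is None:
        raise ValueError(f"decay is required for update_rule={update_rule!r}")
    if decay.shape == (b, t, kv_num_heads * d_k):
        return "per_key_dim"  # one log-decay per key dimension (GLA/RWKV-6 style)
    if decay.shape == (b, t, kv_num_heads):
        return "per_head"  # one scalar log-decay per head and token
    raise ValueError(f"unexpected decay shape {decay.shape}")


def l2_normalize_keys(key, kv_num_heads, eps=1e-6):
    """Scale each head's d_k-dimensional key vector to unit L2 norm (eps avoids division by zero)."""
    b, t, packed = key.shape
    k4 = key.reshape(b, t, kv_num_heads, packed // kv_num_heads)
    norms = np.linalg.norm(k4, axis=-1, keepdims=True)
    return (k4 / np.maximum(norms, eps)).reshape(b, t, packed)


key = l2_normalize_keys(np.random.default_rng(0).standard_normal((1, 3, 2 * 4)), kv_num_heads=2)
layout = classify_optional_inputs("gated_delta", key, 2,
                                  decay=np.zeros((1, 3, 2)), beta=np.full((1, 3, 1), 0.5))
print(layout)  # per_head
```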
Outputs
output (heterogeneous) - T:
Attention output with 3D packed shape (B, T, H_q * d_v).
present_state (heterogeneous) - S:
Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.
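The output shapes follow directly from the packed input shapes and the head counts: each query head reads a d_v-wide result from its group's (d_k, d_v) state, so the output's last dimension is H_q * d_v even though value is packed with H_kv heads. A minimal sketch with a hypothetical helper; the same-d_k check reflects that q_t^T S_t requires query and key heads to share d_k.

```python
def infer_output_shapes(query_shape, key_shape, value_shape, q_num_heads, kv_num_heads):
    """Return (output_shape, present_state_shape) from the packed 3D input shapes."""
    (b, t, q_packed), (_, _, k_packed), (_, _, v_packed) = query_shape, key_shape, value_shape
    if q_packed % q_num_heads or k_packed % kv_num_heads or v_packed % kv_num_heads:
        raise ValueError("packed dimensions must be divisible by their head counts")
    d_k = q_packed // q_num_heads
    if k_packed // kv_num_heads != d_k:
        raise ValueError("query and key must have the same per-head size d_k")
    d_v = v_packed // kv_num_heads
    return (b, t, q_num_heads * d_v), (b, kv_num_heads, d_k, d_v)


# Decoding step (T=1) with 8 query heads sharing 2 KV heads, d_k=128, d_v=64.
out_shape, state_shape = infer_output_shapes((2, 1, 8 * 128), (2, 1, 2 * 128), (2, 1, 2 * 64), 8, 2)
print(out_shape, state_shape)  # (2, 1, 512) (2, 2, 128, 64)
```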
Type Constraints
T in (tensor(bfloat16), tensor(float), tensor(float16)): Constrain activation input and output types to float16, bfloat16, or float32 tensors.
S in (tensor(bfloat16), tensor(float), tensor(float16)): Constrain state types to float16, bfloat16, or float32 tensors. S should be either float32 or the same type as T; float32 is preferred for numerical stability on long sequences.