LinearAttention

LinearAttention - 27

Version

  • name: LinearAttention

  • domain: main

  • since_version: 27

  • function: True

  • support_level: SupportType.COMMON

  • shape inference: True

This version of the operator has been available since version 27.

Summary

Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1).

The query, key, value, and (where applicable) decay/beta inputs use a 3D packed format (B, T, H*D), where the heads are flattened into the last dimension. The q_num_heads and kv_num_heads attributes are always required and are used to unpack these tensors to 4D internally for computation. The optional past_state input and the present_state output are 4D with shape (B, H_kv, d_k, d_v).
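As a rough illustration of the packed layout, the sketch below unpacks a (B, T, H*D) tensor into a 4D view and packs it back. The helper names unpack_heads and pack_heads are hypothetical, and both the head-major ordering of the last dimension (index h*D + d) and the (B, H, T, D) layout of the 4D view are assumptions; the operator only states that heads are flattened into the last dimension.

```python
import numpy as np

def unpack_heads(x, num_heads):
    """(B, T, H*D) -> (B, H, T, D), assuming head-major packing (h*D + d)."""
    B, T, HD = x.shape
    if HD % num_heads != 0:
        raise ValueError("last dimension must be divisible by num_heads")
    return x.reshape(B, T, num_heads, HD // num_heads).transpose(0, 2, 1, 3)

def pack_heads(x):
    """(B, H, T, D) -> (B, T, H*D), the inverse of unpack_heads."""
    B, H, T, D = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * D)

# B=2, T=3, H=4, D=2
q = np.arange(2 * 3 * 8, dtype=np.float32).reshape(2, 3, 8)
q4 = unpack_heads(q, num_heads=4)
print(q4.shape)
```

Under these assumptions, head 1 of the first token occupies elements 2 and 3 of the packed last dimension, and packing the 4D view recovers the original tensor.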

Group-query attention (GQA) is supported: q_num_heads must be a positive multiple of kv_num_heads. When q_num_heads == kv_num_heads, this reduces to multi-head linear attention. When q_num_heads > kv_num_heads, each KV head (and its recurrent state) is shared by q_num_heads / kv_num_heads query heads. Multi-query attention is the special case kv_num_heads == 1.
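The head sharing can be sketched as a mapping from query head index to KV head index. The function name kv_head_for_query_head is hypothetical, and the contiguous grouping (consecutive query heads share one KV head) is an assumption; the text above only fixes the group size.

```python
def kv_head_for_query_head(h_q, q_num_heads, kv_num_heads):
    """Return the KV head (and recurrent state) used by query head h_q."""
    if q_num_heads <= 0 or kv_num_heads <= 0 or q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a positive multiple of kv_num_heads")
    group_size = q_num_heads // kv_num_heads  # query heads per KV head
    return h_q // group_size  # assumption: groups are contiguous

# 8 query heads sharing 2 KV heads: groups of 4
gqa = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
# multi-query attention: every query head reads KV head 0
mqa = [kv_head_for_query_head(h, 8, 1) for h in range(8)]
print(gqa, mqa)
```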

The update_rule attribute selects the recurrence type:

  • “linear”: S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t

  • “gated”: S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t

  • “delta”: S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t

  • “gated_delta”: S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t

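where S_t is the recurrent state of one KV head with shape (d_k, d_v), g_t is the decay (in log-space), β_t is the update rate, scale is the output scaling factor, and ⊗ denotes the outer product.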
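The four recurrences can be written as one single-head, single-token step. This is a minimal NumPy sketch, not the operator's implementation: the function name recurrent_step is hypothetical, S has shape (d_k, d_v), and applying a per-key-dimension decay row-wise to S is an assumption. For the gated_delta rule the code decays the state first and then applies the delta update, which matches the formula above for a scalar g_t.

```python
import numpy as np

def recurrent_step(S, q, k, v, rule, scale=1.0, g=None, beta=None):
    """One token of the recurrence for one head. Returns (S_t, o_t)."""
    if rule in ("gated", "gated_delta"):
        if g is None:
            raise ValueError("decay is required for gated rules")
        decay = np.exp(np.asarray(g, dtype=S.dtype))
        if decay.ndim == 1:           # per-key-dimension decay, shape (d_k,)
            decay = decay[:, None]    # assumption: applied row-wise to S
        S = decay * S                 # exp(g_t) * S_{t-1}
    if rule in ("delta", "gated_delta"):
        if beta is None:
            raise ValueError("beta is required for delta rules")
        # S + beta * outer(k, v - S^T k), with S already decayed if gated
        S = S + beta * np.outer(k, v - S.T @ k)
    elif rule in ("linear", "gated"):
        S = S + np.outer(k, v)
    else:
        raise ValueError(f"unknown update_rule: {rule}")
    return S, scale * (q @ S)         # o_t = scale * q_t^T S_t

k = np.array([1.0, 0.0])              # unit-norm key
v1 = np.array([1.0, 2.0, 3.0])
v2 = np.array([4.0, 5.0, 6.0])

S, _ = recurrent_step(np.zeros((2, 3)), k, k, v1, "delta", beta=1.0)
S, o_delta = recurrent_step(S, k, k, v2, "delta", beta=1.0)

S, _ = recurrent_step(np.zeros((2, 3)), k, k, v1, "linear")
S, o_linear = recurrent_step(S, k, k, v2, "linear")
print(o_delta, o_linear)
```

With a unit-norm key and β_t = 1, writing a second value under the same key overwrites the first under the delta rule, whereas the linear rule accumulates both values.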

Semantics: Equivalent to running the recurrent update sequentially for each token, but may be implemented using chunk-parallel algorithms for GPU efficiency.
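To make the equivalence concrete, the sketch below compares a token-by-token loop with a chunked evaluation for the "linear" rule only, for a single head starting from a zero state. The function names are hypothetical, and this is not the WY-based algorithm used for the delta rules; it only shows that splitting the sequence into chunks leaves the result unchanged.

```python
import numpy as np

def linear_sequential(q, k, v, scale):
    """Reference: S_t = S_{t-1} + outer(k_t, v_t); o_t = scale * q_t^T S_t."""
    T, d_k = k.shape
    S = np.zeros((d_k, v.shape[1]))
    out = np.empty((T, v.shape[1]))
    for t in range(T):
        S = S + np.outer(k[t], v[t])
        out[t] = scale * (q[t] @ S)
    return out, S

def linear_chunked(q, k, v, scale, chunk_size):
    """Same result, computed one chunk of tokens at a time."""
    T, d_k = k.shape
    S = np.zeros((d_k, v.shape[1]))
    out = np.empty((T, v.shape[1]))
    for s in range(0, T, chunk_size):
        qc, kc, vc = q[s:s + chunk_size], k[s:s + chunk_size], v[s:s + chunk_size]
        inter = qc @ S                      # contribution of earlier chunks
        intra = np.tril(qc @ kc.T) @ vc     # causal part inside the chunk
        out[s:s + chunk_size] = scale * (inter + intra)
        S = S + kc.T @ vc                   # carry the state to the next chunk
    return out, S

rng = np.random.default_rng(0)
q = rng.standard_normal((10, 4))
k = rng.standard_normal((10, 4))
v = rng.standard_normal((10, 3))

ref_out, ref_state = linear_sequential(q, k, v, scale=0.5)
chk_out, chk_state = linear_chunked(q, k, v, scale=0.5, chunk_size=4)
print(np.allclose(ref_out, chk_out), np.allclose(ref_state, chk_state))
```

The lower-triangular mask includes the diagonal because o_t is computed from S_t, which already contains the current token.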

Attributes

  • chunk_size - INT (default is 64):

    Chunk size for the chunk-parallel WY decomposition during prefill (T>1). Tuning hint; does not affect output correctness.

  • kv_num_heads - INT (required) :

    Number of key/value heads. Always required.

  • q_num_heads - INT (required) :

    Number of query heads. Always required.

  • scale - FLOAT (default is 0.0):

    Output scaling factor. When 0.0 (the default), the operator derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set a non-zero value to override.

  • update_rule - STRING (default is gated_delta):

    The update rule for the linear attention recurrence. One of: ‘linear’, ‘gated’, ‘delta’, ‘gated_delta’. Default is ‘gated_delta’.
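The default handling of the scale attribute described above can be sketched as follows. The helper name resolve_scale is hypothetical; the rule itself (0.0 means 1/sqrt(d_k) with d_k = query.shape[-1] / q_num_heads) is taken from the attribute description.

```python
import math

def resolve_scale(scale, query_last_dim, q_num_heads):
    """Return the effective output scaling factor."""
    if scale != 0.0:
        return scale                       # explicit override
    if query_last_dim % q_num_heads != 0:
        raise ValueError("query.shape[-1] must be divisible by q_num_heads")
    d_k = query_last_dim // q_num_heads    # per-head key dimension
    return 1.0 / math.sqrt(d_k)

# query packed as (B, T, 512) with 8 query heads gives d_k = 64
default_scale = resolve_scale(0.0, 512, 8)
custom_scale = resolve_scale(0.5, 512, 8)
print(default_scale, custom_scale)
```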

Inputs

Between 3 and 6 inputs.

  • query (heterogeneous) - T:

    Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.

  • key (heterogeneous) - T:

    Key vectors with 3D packed shape (B, T, H_kv * d_k). Should be L2-normalized for delta/gated_delta modes.

  • value (heterogeneous) - T:

    Value vectors with 3D packed shape (B, T, H_kv * d_v).

  • past_state (optional, heterogeneous) - S:

    Recurrent state from previous step with shape (B, H_kv, d_k, d_v). Always 4D. If not provided, defaults to zeros.

  • decay (optional, heterogeneous) - T:

    Exponential decay gate in log-space. 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for ‘gated’ and ‘gated_delta’ modes.

  • beta (optional, heterogeneous) - T:

    Update rate (sigmoid output). 3D packed shape: (B, T, H_kv) or (B, T, 1). Required for ‘delta’ and ‘gated_delta’ modes.
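The accepted decay and beta shapes can be reshaped so that a per-token slice broadcasts against a state of shape (B, H_kv, d_k, d_v). This is a sketch under assumptions: the helper names decay_view and beta_view are hypothetical, the packed per-key-dimension decay is assumed to be head-major, and the decay is assumed to act along the d_k axis of the state.

```python
import numpy as np

def decay_view(decay, kv_num_heads, d_k):
    """(B, T, H_kv) or (B, T, H_kv*d_k) -> broadcastable 5D view."""
    B, T, last = decay.shape
    if last == kv_num_heads:                 # per-head scalar decay
        return decay.reshape(B, T, kv_num_heads, 1, 1)
    if last == kv_num_heads * d_k:           # per-key-dimension decay
        return decay.reshape(B, T, kv_num_heads, d_k, 1)
    raise ValueError("decay last dim must be H_kv or H_kv * d_k")

def beta_view(beta, kv_num_heads):
    """(B, T, H_kv) or (B, T, 1) -> broadcastable 5D view."""
    B, T, last = beta.shape
    if last not in (1, kv_num_heads):
        raise ValueError("beta last dim must be 1 or H_kv")
    return beta.reshape(B, T, last, 1, 1)

B, T, H_kv, d_k, d_v = 2, 6, 3, 4, 5
state = np.ones((B, H_kv, d_k, d_v))
g_head = np.full((B, T, H_kv), np.log(0.5))
g_key = np.zeros((B, T, H_kv * d_k))
beta = np.full((B, T, 1), 0.25)

t = 0  # apply the gates of the first token
decayed_head = np.exp(decay_view(g_head, H_kv, d_k)[:, t]) * state
decayed_key = np.exp(decay_view(g_key, H_kv, d_k)[:, t]) * state
scaled = beta_view(beta, H_kv)[:, t] * state
print(decayed_head.shape, decayed_key.shape, scaled.shape)
```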

Outputs

  • output (heterogeneous) - T:

    Attention output with 3D packed shape (B, T, H_q * d_v).

  • present_state (heterogeneous) - S:

    Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.
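The role of past_state and present_state can be illustrated by splitting a sequence into a prefill call (T>1) followed by a decode call (T=1). The sketch below uses the "gated" rule for a single head; the function name gated_pass is hypothetical, and drawing non-positive decay values (so that exp(g_t) lies in (0, 1]) is an assumption made for the example.

```python
import numpy as np

def gated_pass(q, k, v, g, past_state=None, scale=1.0):
    """Run the 'gated' recurrence over T tokens; returns (output, present_state)."""
    T, d_k = k.shape
    d_v = v.shape[1]
    # A missing past_state defaults to zeros, as described for the input.
    S = np.zeros((d_k, d_v)) if past_state is None else past_state.copy()
    out = np.empty((T, d_v))
    for t in range(T):
        S = np.exp(g[t]) * S + np.outer(k[t], v[t])
        out[t] = scale * (q[t] @ S)
    return out, S

rng = np.random.default_rng(0)
T, d_k, d_v = 6, 4, 3
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
v = rng.standard_normal((T, d_v))
g = -rng.random(T)  # assumption: log-space decay <= 0

# One pass over all six tokens.
full_out, full_state = gated_pass(q, k, v, g)

# Prefill on the first five tokens, then decode the last one from the saved state.
prefill_out, state = gated_pass(q[:5], k[:5], v[:5], g[:5])
decode_out, state = gated_pass(q[5:], k[5:], v[5:], g[5:], past_state=state)

print(np.allclose(np.concatenate([prefill_out, decode_out]), full_out))
```

Feeding the present_state of one call into the past_state of the next reproduces the outputs and final state of a single pass over the whole sequence.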

Type Constraints

  • T in ( tensor(bfloat16), tensor(float), tensor(float16) ):

    Constrain activation input and output types to float16, bfloat16, or float32 tensors.

  • S in ( tensor(bfloat16), tensor(float), tensor(float16) ):

    Constrain state types to float16, bfloat16, or float32 tensors. The state type should be float32 or the same as T; float32 is recommended for numerical stability on long sequences.
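The numerical-stability note can be illustrated with a toy accumulation. The sketch below runs the "linear" rule with d_k = d_v = 1 and k_t = v_t = 1 for 4096 tokens, once with a float16 state and once with a float32 state. The function name accumulate_ones is hypothetical, and the example is deliberately extreme; real inputs lose precision more gradually.

```python
import numpy as np

def accumulate_ones(steps, dtype):
    """State after `steps` updates S = S + outer(k, v) with k = v = [1]."""
    S = np.zeros((1, 1), dtype=dtype)
    kv = np.ones((1, 1), dtype=dtype)
    for _ in range(steps):
        S = S + kv
    return float(S[0, 0])

half_state = accumulate_ones(4096, np.float16)
single_state = accumulate_ones(4096, np.float32)
print(half_state, single_state)  # 2048.0 4096.0
```

The float16 state stops growing at 2048 because the spacing between representable float16 values there is 2, so adding 1 rounds back to 2048; the float32 state reaches the exact value 4096.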