(l-onnx-doc-TensorScatter)= # TensorScatter (l-onnx-op-tensorscatter-24)= ## TensorScatter - 24 ### Version - **name**: [TensorScatter (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#TensorScatter) - **domain**: `main` - **since_version**: `24` - **function**: `False` - **support_level**: `SupportType.COMMON` - **shape inference**: `True` This version of the operator has been available **since version 24**. ### Summary TensorScatter is a generic tensor update operation, motivated by the requirements for KV cache updates for Attention ops commonly found in LLMs. It is a functional operation that models an in-place update to a KV cache buffer. The past and present cache tensors have the same shape (batch_size, D1, D2, ..., max_sequence_length, ..., Dn), with the sequence dimension (indicated by the `axis` attribute) being max_sequence_length, so the sizes of these tensors do not need to grow between iterations. The `update` tensor's shape only differs from the cache tensors in the sequence dimension: (batch_size, D1, D2, ..., sequence_length, ..., Dn), where sequence_length <= max_sequence_length. The optional `write_indices` input indicates the write index for each sample in the batch, assumed to be zero if not provided. When the `mode` attribute is set to "circular", the write index is modulo max_sequence_length. The operation can be described using the following pseudocode: ``` for prefix_idx in np.ndindex(past_cache.shape[:axis]): batch_idx = prefix_idx[0] for sequence_idx in range(sequence_length): cache_idx = (*prefix_idx, write_indices[batch_idx] + sequence_idx) if mode == "circular": cache_idx = tuple(np.mod(np.asarray(cache_idx), max_sequence_length)) update_idx = (*prefix_idx, sequence_idx) present_cache[cache_idx] = update[update_idx] ``` During the prefill phase of attention, only the first two inputs are needed. During the decode phase, `write_indices` is also needed so that the incoming key or value update can be appended after the last valid token for each sample in the batch. ### Attributes * **axis - INT** (default is `'-2'`): Sequence dimension of the `past_cache` and `update` tensors. It cannot be 0 (the batch dimension). Default is -2. * **mode - STRING** (default is `'linear'`): Write mode of cache update. Supported modes include `linear` and `circular`. `linear` mode requires write_indices+sequence_length<=max_sequence_length. For `circular` mode, the updates happen in wrap-around fashion, ie, the update index is modulo `max_sequence_length` ### Inputs Between 2 and 3 inputs. - **past_cache** (heterogeneous) - **T**: Past state cache for key or value with shape `(batch_size, D1, D2, ..., max_sequence_length, ..., Dn)`. - **update** (heterogeneous) - **T**: New update tensor with shape `(batch_size, D1, D2, ..., sequence_length, ..., Dn)`. - **write_indices** (optional, heterogeneous) - **tensor(int64)**: Write indices for the incoming update tensor in the cache. Shape is `(batch_size,)`. Assumed to be all zeros if not provided. ### Outputs - **present_cache** (heterogeneous) - **T**: Updated cache. Same shape as `past_cache`. ### Type Constraints * **T** in ( `tensor(bfloat16)`, `tensor(bool)`, `tensor(complex128)`, `tensor(complex64)`, `tensor(double)`, `tensor(float)`, `tensor(float16)`, `tensor(float4e2m1)`, `tensor(float8e4m3fn)`, `tensor(float8e4m3fnuz)`, `tensor(float8e5m2)`, `tensor(float8e5m2fnuz)`, `tensor(float8e8m0)`, `tensor(int16)`, `tensor(int32)`, `tensor(int4)`, `tensor(int64)`, `tensor(int8)`, `tensor(string)`, `tensor(uint16)`, `tensor(uint32)`, `tensor(uint4)`, `tensor(uint64)`, `tensor(uint8)` ): Constrain input and output types to any tensor type.