RMSNormalization

RMSNormalization - 23

Version

  • name: RMSNormalization

  • domain: main

  • since_version: 23

  • function: True

  • support_level: SupportType.COMMON

  • shape inference: True

This version of the operator has been available since version 23.

Summary

This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467. The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions, where D is the number of dimensions of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the RMS norm is computed over the last 2 dimensions of the input. The first stage, which normalizes X, is described by the following equations.

XSquared = Mul(X, X)
XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
MeanSquareEpsilon = Add(XSquaredMean, epsilon)
RMS = Sqrt(MeanSquareEpsilon)
Normalized = Div(X, RMS)

where normalized_axes is [axis, ..., rank of X - 1]. The variable RMS stands for root mean square. Depending on the stash_type attribute, the actual computation is carried out in a different floating-point precision. For example, if stash_type is 1, this operator casts all input variables to 32-bit float, performs the computation, and finally casts Normalized back to the original type of X. The second stage then scales the outcome of the first stage using:

Y = Mul(Normalized, Scale)

Let d[i] indicate the i-th dimension of X. If X’s shape is [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]], the shape of RMS is [d[0], ..., d[axis-1], 1, ..., 1]. Y and X have the same shape. This operator supports unidirectional broadcasting (Scale should be unidirectionally broadcastable to tensor X); for more details, see Broadcasting in ONNX.
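The two stages and the shape rules above can be sketched in NumPy as follows. This is a minimal, non-authoritative sketch, not the ONNX reference implementation: the function name `rms_normalization` and the `stash_dtype` parameter are hypothetical, `stash_type` = 1 is modeled as computing in `np.float32`, and casting the result to the dtype of `scale` is an assumption based on Y being constrained to type V.

```python
import numpy as np


def rms_normalization(x: np.ndarray, scale: np.ndarray, axis: int = -1,
                      epsilon: float = 1e-5, stash_dtype=np.float32) -> np.ndarray:
    """Hypothetical NumPy sketch of both stages of RMSNormalization.

    Assumptions: stash_type=1 is modeled by stash_dtype=np.float32, and the
    result is cast to scale's dtype because Y is constrained to type V.
    """
    rank = x.ndim
    if axis < 0:
        axis += rank  # the default -1 selects only the last dimension
    normalized_axes = tuple(range(axis, rank))  # [axis, ..., rank of X - 1]

    # Unidirectional broadcasting: scale must broadcast to X's shape
    # without changing that shape.
    if np.broadcast_shapes(scale.shape, x.shape) != x.shape:
        raise ValueError(f"scale {scale.shape} does not broadcast to {x.shape}")

    # Stage one, computed in the stash precision.
    xs = x.astype(stash_dtype)
    x_squared = xs * xs                                         # XSquared
    x_squared_mean = np.mean(x_squared, axis=normalized_axes,   # XSquaredMean
                             keepdims=True)
    rms = np.sqrt(x_squared_mean + stash_dtype(epsilon))        # RMS
    # RMS keeps d[0], ..., d[axis-1] and has 1 in every normalized dimension.
    assert rms.shape == x.shape[:axis] + (1,) * (rank - axis)
    normalized = (xs / rms).astype(x.dtype)                     # back to X's type

    # Stage two: Y = Mul(Normalized, Scale); Y has the same shape as X.
    return (normalized * scale).astype(scale.dtype)


# Usage: each row of a (2, 3) input is normalized over its last dimension.
x = np.arange(1, 7, dtype=np.float32).reshape(2, 3)
scale = np.array([1.0, 2.0, 0.5], dtype=np.float32)
y = rms_normalization(x, scale)  # y has shape (2, 3)
```

Validating the broadcast up front mirrors the "unidirectionally broadcastable" requirement: a scale of shape (1, 2, 3) would broadcast with X but enlarge its shape, so the sketch rejects it.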

Attributes

  • axis - INT (default is '-1'):

    The first normalization dimension: normalization is performed over dimensions axis : rank(X), that is, from axis through the last dimension. A negative value counts dimensions from the back.

  • epsilon - FLOAT (default is '1e-05'):

    The epsilon value to use to avoid division by zero.

  • stash_type - INT (default is '1'):

    The floating-point precision used in stage one of the computation.
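To connect the axis attribute with the normalized_shape description in the summary, the hypothetical helper below lists the dimensions that are reduced for a given input rank and axis. It assumes negative values count from the back, as the default of -1 implies; the range check is an assumption of this sketch rather than something stated in this section.

```python
def normalized_axes(rank: int, axis: int = -1) -> list[int]:
    """Hypothetical helper: the axes RMSNormalization reduces over.

    The range check is an assumption of this sketch.
    """
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    if axis < 0:
        axis += rank  # negative values count from the back
    return list(range(axis, rank))  # dimensions axis : rank(X)


# A 4-D input whose last two dimensions form normalized_shape (3, 5),
# as in the summary example, corresponds to axis = -2 (equivalently 2).
print(normalized_axes(4, -2))  # [2, 3]
print(normalized_axes(4))      # [3]  (default axis = -1)
```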

Inputs

  • X (heterogeneous) - T:

    The input tensor to be normalized. In general, the shape is (D1, D2, …, Dn) for n-dimensional data, and the root mean squared norm is taken over the last D dimensions, where D is determined by the axis attribute.

  • scale (heterogeneous) - V:

    Scale tensor. Its shape should be broadcastable to the normalized shape.

Outputs

  • Y (heterogeneous) - V:

    Output data tensor, with the same shape as X.

Type Constraints

  • T in ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ):

    Constrain input X type to float tensors.

  • V in ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ):

    Constrain output Y and scale type to float tensors.
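The stash_type attribute matters most for low-precision types such as tensor(float16). In the NumPy sketch below, the helper `stage_one` is hypothetical, and passing np.float16 as the "stash" precision is only a comparison, not a claim about which stash_type values the operator accepts. Squaring a float16 value of 300 gives 90000, which exceeds float16's largest finite value (65504) and becomes infinity; computing stage one in float32, as stash_type = 1 specifies, avoids the overflow.

```python
import numpy as np


def stage_one(x: np.ndarray, stash_dtype, epsilon: float = 1e-5) -> np.ndarray:
    """Hypothetical sketch: stage one over the last axis in stash_dtype,
    with the result cast back to X's type."""
    xs = x.astype(stash_dtype)
    rms = np.sqrt(np.mean(xs * xs, axis=-1, keepdims=True) + stash_dtype(epsilon))
    return (xs / rms).astype(x.dtype)


x = np.array([300.0, 400.0], dtype=np.float16)

with np.errstate(over="ignore"):
    # 300 * 300 overflows float16, so RMS is inf and both outputs are 0.0.
    no_stash = stage_one(x, np.float16)

# Computed in float32, then cast back to float16: roughly [0.8485, 1.1314].
stashed = stage_one(x, np.float32)
```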