RMSNormalization¶
RMSNormalization - 23¶
Version¶
domain:
mainsince_version:
23function:
Truesupport_level:
SupportType.COMMONshape inference:
True
This version of the operator has been available since version 23.
Summary¶
This is RMS normalization defined in ONNX as function as described in the paper https://arxiv.org/pdf/1910.07467. The overall computation can be split into two stages. The root mean squared norm is taken over the last D dimensions, where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the rms norm is computed over the last 2 dimensions of the input. The computation required by standardization can be described by the following equations.
XSquared = Mul(X, X)
XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
MeanSquareEpsilon = Add(XSquaredMean, epsilon)
RMS = Sqrt(MeanSquareEpsilon)
Normalized = Div(X, RMS)
where normalized_axes is [axis, ..., rank of X - 1]. The variables RMS stand for root mean square,
Depending on stash_type attribute, the actual computation
must happen in different floating-point precision.
For example, if stash_type is 1, this operator casts
all input variables to 32-bit float, perform the computation, and
finally cast Normalized back to the original type of X.
The second stage then scales the outcome of the first stage using:
Y= Mul(Normalized, Scale)
Let d[i] indicate the i-th dimension of X.
If X’s shape is [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]],
the shape of RMS is [d[0], ..., d[axis-1], 1, ..., 1].
Y and X have the same shape. This operator supports unidirectional broadcasting
(Scale should be unidirectional broadcastable to tensor X);
for more details please check Broadcasting in ONNX.
Attributes¶
axis - INT (default is
'-1'):The first normalization dimension. If rank(X) is r, axis’ allowed range is [-r, r). Negative value means counting dimensions from the back.
epsilon - FLOAT (default is
'1e-05'):The epsilon value to use to avoid division by zero.
stash_type - INT (default is
'1'):The floating-point precision used in stage one of the computation.
Inputs¶
X (heterogeneous) - T:
The input tensor to be normalized. In general, the shape is (D1, D2, … , Dn) for n-dimensional data, where the root mean squared norm is taken over the last D dimensions, D is determined by the axis attribute.
scale (heterogeneous) - V:
Scale tensor. Scale tensor shape should be broadcastable to the normalized shape.
Outputs¶
Y (heterogeneous) - V:
Output data tensor. Same shape as X
Type Constraints¶
T in (
tensor(bfloat16),tensor(double),tensor(float),tensor(float16)):Constrain input X type to float tensors.
V in (
tensor(bfloat16),tensor(double),tensor(float),tensor(float16)):Constrain output Y and scale type to float tensors.