RMSNormalization¶
RMSNormalization - 23¶
Version¶
domain: main
since_version: 23
function: True
support_level: SupportType.COMMON
shape inference: True
This version of the operator has been available since version 23.
Summary¶
This is RMS normalization, defined in ONNX as a function, as described in the paper https://arxiv.org/pdf/1910.07467. The overall computation can be split into two stages. In the first stage, the root mean square is taken over the last D dimensions of the input, where D is determined by the axis attribute; for example, if the input is 4-dimensional and axis is 2, the RMS is computed over the last 2 dimensions. The first stage can be described by the following equations.
XSquared = Mul(X, X)
XSquaredMean = ReduceMean<axes=normalized_axes>(XSquared)
MeanSquareEpsilon = Add(XSquaredMean, epsilon)
RMS = Sqrt(MeanSquareEpsilon)
Normalized = Div(X, RMS)
where normalized_axes is [axis, ..., rank of X - 1]. The variable RMS stands for root mean square. Depending on the stash_type attribute, the actual computation must happen in a different floating-point precision. For example, if stash_type is 1, this operator casts all input variables to 32-bit float, performs the computation, and finally casts Normalized back to the original type of X.
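The stage-one equations can be sketched in NumPy as follows (a minimal illustration under assumed names; rms_norm_stage_one is not part of the ONNX API):

```python
import numpy as np

def rms_norm_stage_one(X, axis=-1, epsilon=1e-5, stash_type=1):
    """Sketch of stage one: Normalized = X / Sqrt(ReduceMean(X*X) + epsilon).

    stash_type == 1: compute in 32-bit float, then cast the result
    back to the original type of X.
    """
    original_dtype = X.dtype
    if stash_type == 1:
        X = X.astype(np.float32)
    # normalized_axes = [axis, ..., rank of X - 1]
    start = axis % X.ndim  # support negative axis values
    normalized_axes = tuple(range(start, X.ndim))
    x_squared_mean = np.mean(X * X, axis=normalized_axes, keepdims=True)
    rms = np.sqrt(x_squared_mean + epsilon)
    return (X / rms).astype(original_dtype)
```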
The second stage then scales the outcome of the first stage using:
Y = Mul(Normalized, Scale)
Let d[i] indicate the i-th dimension of X. If X's shape is [d[0], ..., d[axis-1], d[axis], ..., d[rank-1]], the shape of RMS is [d[0], ..., d[axis-1], 1, ..., 1]. Y and X have the same shape. This operator supports unidirectional broadcasting (Scale should be unidirectionally broadcastable to tensor X); for more details please check Broadcasting in ONNX.
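The shape behavior described above can be checked with a short NumPy sketch (illustrative only; the names are not part of the ONNX API). With X of shape (2, 3, 5) and axis = 1, RMS has shape (2, 1, 1), and a Scale of shape (3, 5) is unidirectionally broadcastable to X:

```python
import numpy as np

X = np.random.randn(2, 3, 5).astype(np.float32)
scale = np.full((3, 5), 0.5, dtype=np.float32)  # broadcastable to X's shape

axis = 1
normalized_axes = tuple(range(axis, X.ndim))  # (1, 2)
rms = np.sqrt(np.mean(X * X, axis=normalized_axes, keepdims=True) + 1e-5)
print(rms.shape)  # (2, 1, 1), i.e. [d[0], 1, 1]

Y = (X / rms) * scale  # stage two: scale the normalized tensor
print(Y.shape)  # (2, 3, 5), same shape as X
```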
Attributes¶
axis - INT (default is '-1'): The first normalization dimension: normalization will be performed along dimensions axis : rank(inputs).
epsilon - FLOAT (default is '1e-05'): The epsilon value to use to avoid division by zero.
stash_type - INT (default is '1'): The floating-point precision used in stage one of the computation.
Inputs¶
X (heterogeneous) - T:
The input tensor to be normalized. In general, the shape is (D1, D2, ..., Dn) for n-dimensional data, where the root mean square is taken over the last D dimensions; D is determined by the axis attribute.
scale (heterogeneous) - V:
Scale tensor. Scale tensor shape should be broadcastable to the normalized shape.
Outputs¶
Y (heterogeneous) - V:
Output data tensor. Same shape as X.
Type Constraints¶
T in (tensor(bfloat16), tensor(double), tensor(float), tensor(float16)): Constrain input X type to float tensors.
V in (tensor(bfloat16), tensor(double), tensor(float), tensor(float16)): Constrain output Y and scale type to float tensors.