ONNX Shape Inference¶
ONNX provides an optional implementation of shape inference on ONNX graphs. This implementation covers each of the core operators, as well as provides an interface for extensibility. Therefore, you may choose to invoke the existing shape inference functionality on your graphs, or to define shape inference implementations to go along with your custom operators (or both!). Shape inference functions are stored as a member of the OpSchema objects.
In the ONNX 1.10 release, symbol generation and propagation, along with shape data propagation, were added to ONNX graph-level shape inference. The detailed proposal is here
Background¶
Please see this section of IR.md for a review of static tensor shapes.
In particular, a static tensor shape (represented by a TensorShapeProto) is distinct from
a runtime tensor shape. This distinction is important when the exact runtime tensor shape is
not known statically (that is, at compile time).
- A `Tensor` with an undefined `shape` field represents a tensor of unknown rank.
- A `Tensor` with a defined `shape` represents a tensor of known rank.
- Each `Dimension` of a `TensorShapeProto` can have a known integer value (represented by the `dim_value` field), an unknown value represented by a symbolic identifier (the `dim_param` field), or neither field defined (in which case it represents an anonymous unknown value).
Invoking Shape Inference¶
Shape inference can be invoked either via C++ or Python. The Python API is described, with an example, here.
The C++ API consists of a single function:

```cpp
shape_inference::InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry);
```
The first argument is a ModelProto to perform shape inference on,
which is annotated in-place with shape information. The second
argument is optional.
Limitations¶
Shape inference is not guaranteed to be complete. In particular, some dynamic behaviors block the flow of shape inference, for example a Reshape to a dynamically-provided shape. Also, not all operators are required to have a shape inference implementation.
Shape inference works only with constants and simple variables. It
does not support arithmetic expressions containing variables. For
example, Concat on tensors of shapes (5, 2) and (7, 2) can be
inferred to produce a result of shape (12, 2), but Concat on
tensors of shapes (5, 2) and (N, 2) will simply produce (M, 2),
rather than containing a representation of N+5. Note that differing
unknown symbolic values will be propagated, so the M here represents
an unknown quantity that is the same as other occurrences of M.
These limitations are a property of the current implementation, not fundamental constraints - if you are in need of something more advanced, do let us know!
Type Inference vs. Shape Inference¶
Type inference (determining the element type of outputs) is typically handled
automatically by the schema’s type constraints. When a type constraint variable
(e.g., "T") is shared between an input and an output in the schema definition,
the framework propagates the element type from the input to the output without
any explicit inference code.
However, many existing ops still explicitly call propagateElemTypeFromInputToOutput
as a best practice for robustness. This is harmless when type constraints already
cover the case, and ensures correct behavior regardless of how shape inference
is invoked.
Explicit type inference logic in TypeAndShapeInferenceFunction is only needed when:
- The output type is determined by an attribute rather than an input type (e.g., `Cast`, where the `to` attribute specifies the output element type)
- The output type differs from all input types in a way that cannot be expressed via shared type constraint variables
- The operator uses heterogeneous variadic inputs/outputs (see below)
Homogeneous vs. Heterogeneous variadic inputs/outputs¶
The homogeneous/heterogeneous flag applies only to variadic (repeated) inputs or outputs in the schema definition:
- Homogeneous (the default): All repeated arguments must have the same type. The type constraint variable constrains them to be identical, and the framework enforces and propagates this automatically.
- Heterogeneous: Each repeated argument may have a distinct type. The type constraint variable only describes the set of allowed types; it does not constrain different arguments to have the same type. This is used by operators like `Loop` and `Scan`, whose carried state variables can have mixed types.
When using heterogeneous variadic arguments, the operator’s
TypeAndShapeInferenceFunction must explicitly propagate types for each
individual argument, since the framework cannot do it automatically.
Shape inference, on the other hand, almost always requires explicit logic, since output shapes typically depend on input shapes, attributes, or both.
Implementing Shape Inference For Operators¶
You can add a shape inference function to your operator’s Schema with:

```cpp
OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
```
InferenceFunction is defined in
shape_inference.h, along with the core
interface struct InferenceContext and an assortment of helper
methods. InferenceContext is the core struct which is provided to
your inference function. It allows accessing information about the
operator’s inputs, and also allows writing out inferred information.
To see numerous examples, search for occurrences of
TypeAndShapeInferenceFunction in the codebase. One that is
relatively involved is the implementation for Concat, in
onnx/defs/tensor/defs.cc.
Please note the following points when implementing the shape-inference method for operators to avoid common errors:
- Before accessing the `shape` of any input, the code must check that the shape is available. If unavailable, it should be treated as a dynamic tensor whose rank is unknown and handled appropriately. Usually, the shape-inference logic is guarded by a call to `hasInputShape` or `hasNInputShapes`.
- Before accessing the `dim_value` or `dim_param` of any dimension, the code must check if these fields have a value. In particular, the code must handle the possibility that the dimension may not have a statically known value.
There are several utility functions in shape_inference.h to handle various common situations.
- Use `checkInputRank` for inputs that must have a fixed rank. (See the inference for `RoiAlign` as an example.)
- `unifyInputDim`, `unifyDim`, and `updateOutputShape` can be used when multiple input dims are expected to be the same, and when input dimensions are propagated to specific output dimensions. (See the inference for `RoiAlign` for an example.)
- The overloaded operators `*` and `/` can be used on symbolic dimensions when output dimensions are computed from input dimensions using arithmetic. (See the inference for `SpaceToDepth` for an example.)
These utilities handle missing shapes and dimensions safely.
Example: Consider a simple matrix-multiplication op that expects inputs of shape
[M,K] and [K,N] and returns an output of shape [M,N]. This can be coded
up as below:
```cpp
// Check that input 0 has rank 2 (if its rank is known).
checkInputRank(ctx, 0, 2);
// Check that input 1 has rank 2 (if its rank is known).
checkInputRank(ctx, 1, 2);
Dim M, K, N;
// Unify the various dimensions, handling missing dimensions/shapes safely.
unifyInputDim(ctx, 0, 0, M);
unifyInputDim(ctx, 0, 1, K);
unifyInputDim(ctx, 1, 0, K);
unifyInputDim(ctx, 1, 1, N);
updateOutputShape(ctx, 0, {M, N});
```