# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Symbolic shape inference pass."""
from __future__ import annotations
__all__ = [
"SymbolicShapeInferencePass",
]
import logging
import onnx_ir as ir
from onnx_ir.shape_inference import _context, _registry
logger = logging.getLogger(__name__)
class SymbolicShapeInferencePass(ir.passes.InPlacePass):
    """Pass that performs symbolic shape inference on the graph.

    This pass traverses the graph in topological order and applies
    registered shape inference functions to each node. Unlike the
    standard ONNX shape inference, this pass:

    - Operates directly on the IR (no serialization)
    - Supports symbolic expressions (e.g., N+1, batch*heads) via SymPy
    - Is extensible via the shape inference registry
    - Supports different merge policies for handling existing shapes

    Example::

        import onnx_ir as ir
        from onnx_ir.passes.common import SymbolicShapeInferencePass

        model = ir.load("model.onnx")
        pass_ = SymbolicShapeInferencePass()
        result = pass_(model)

        # Or use the convenience function
        from onnx_ir.shape_inference import infer_symbolic_shapes

        model = infer_symbolic_shapes(model)
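        # Constructor options can also be passed explicitly. "refine" is the
        # default merge policy; other accepted values are defined by
        # ``_context.ShapeMergePolicy`` and are not listed here. Setting
        # ``warn_on_missing=False`` silences warnings for ops without
        # registered shape inference.
        quiet_pass = SymbolicShapeInferencePass(policy="refine", warn_on_missing=False)
        result = quiet_pass(model)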
"""

    def __init__(
        self,
        policy: _context.ShapeMergePolicy = "refine",
        warn_on_missing: bool = True,
    ) -> None:
        """Initialize the symbolic shape inference pass.

        Args:
            policy: How to merge inferred shapes with existing shapes.
            warn_on_missing: If True, log warnings for ops without registered
                shape inference.
        """
        # Import ops to trigger registration
        from onnx_ir.shape_inference import _ops  # noqa: F401

        super().__init__()
        self.policy = policy
        self.warn_on_missing = warn_on_missing

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run shape inference on the model.

        Args:
            model: The model to process.

        Returns:
            PassResult with the model and whether it was modified.
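
        Example::

            # A minimal sketch of consuming the result. ``model`` and
            # ``modified`` are assumed to be the attribute names matching the
            # two values passed to ``PassResult`` in the return statement below.
            result = SymbolicShapeInferencePass()(model)
            if result.modified:
                print("Shapes or types were updated in place")
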
"""
        ctx = _context.ShapeInferenceContext(model.opset_imports, policy=self.policy)
        modified = False

        # Process all graphs (main graph + subgraphs)
        for graph in model.graphs():
            graph_modified = self._process_graph(ctx, graph)
            modified = modified or graph_modified

        return ir.passes.PassResult(model, modified)

    def _process_graph(self, ctx: _context.ShapeInferenceContext, graph: ir.Graph) -> bool:
        """Process a single graph.

        Args:
            ctx: The shape inference context.
            graph: The graph to process.

        Returns:
            True if any output shapes or types were modified.
        """
        modified = False
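        # Remember which (domain, op_type) pairs have already triggered a
        # "no shape inference registered" warning so each missing op is
        # reported only once.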
        warned_ops: set[tuple[str, str]] = set()

        # Traverse nodes in topological order
        for node in graph:
            domain = node.domain or ""
            op_type = node.op_type
            opset_version = ctx.get_opset_version(domain)

            # Look up shape inference function
            infer_func = _registry.registry.get(domain, op_type, version=opset_version)
            if infer_func is not None:
                try:
                    # Track which outputs had shapes and dtypes before
                    old_states: list[tuple[ir.Shape | None, ir.TypeProtocol | None]] = []
                    for out in node.outputs:
                        old_states.append((out.shape, out.type))

                    # Run inference
                    infer_func(ctx, node)

                    # Check if any shapes or dtypes changed
                    for out, (old_shape, old_type) in zip(node.outputs, old_states):
                        if out.shape != old_shape or out.type != old_type:
                            modified = True
                except Exception as e:
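                    # Log the failure and keep going; one op's failure should
                    # not abort inference for the rest of the graph.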
                    logger.warning(
                        "Shape inference failed for %s::%s: %s",
                        domain or "ai.onnx",
                        op_type,
                        e,
                    )
            elif self.warn_on_missing:
                key = (domain, op_type)
                if key not in warned_ops:
                    logger.warning(
                        "No shape inference registered for %s::%s",
                        domain or "ai.onnx",
                        op_type,
                    )
                    warned_ops.add(key)

        return modified