# Source code for onnx_ir.passes.common.shape_inference

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Shape inference pass using onnx.shape_inference."""

from __future__ import annotations

__all__ = [
    "ShapeInferencePass",
    "infer_shapes",
]

import logging

import onnx  # noqa: TID251

import onnx_ir as ir
from onnx_ir.passes.common import _c_api_utils

logger = logging.getLogger(__name__)


def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Copy inferred shapes and dtypes from ``inferred_proto`` into ``model``.

    ``model`` is updated in place. The graphs of the two models are paired
    positionally, and the values inside each pair of graphs are matched by
    name. A shape or dtype on an original value is overwritten only when the
    inferred counterpart is not ``None`` and differs from the current one.

    Args:
        model: The original IR model. Its values are mutated in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if at least one shape or dtype in ``model`` was updated,
        False otherwise.
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    modified = False
    # NOTE: zip() pairs graphs by position and stops at the shorter iterable,
    # so this relies on both models yielding their graphs in the same order.
    for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
        original_values = ir.convenience.create_value_mapping(original_graph)
        inferred_values = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in original_values.items():
            if name not in inferred_values:
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
                continue

            inferred_value = inferred_values[name]

            # Never erase existing information: a None result from inference
            # leaves the original shape/dtype untouched.
            inferred_shape = inferred_value.shape
            if inferred_shape is not None and value.shape != inferred_shape:
                value.shape = inferred_shape
                modified = True

            inferred_dtype = inferred_value.dtype
            if inferred_dtype is not None and value.dtype != inferred_dtype:
                value.dtype = inferred_dtype
                modified = True
    return modified


class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph.

    Inference is delegated to ``onnx.shape_inference.infer_shapes``; the
    inferred shapes and dtypes are then merged back into the IR model.
    If inference fails, the model is left unchanged.
    """

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run ONNX shape inference and merge the results into ``model``.

        Args:
            model: The IR model to run shape inference on.

        Returns:
            A ``PassResult`` wrapping ``model``. ``modified`` is False when
            inference raised (the failure is logged as a warning, not
            re-raised) or when merging changed no shape or dtype.
        """

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            # Bind the pass options so the helper only has to supply the proto.
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            # NOTE(review): call_onnx_api presumably converts the IR model to a
            # proto before invoking the callback - confirm in _c_api_utils.
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Pass the exception as the %s argument: with no args, logging
            # skips %-formatting and would emit a literal "%s" in the message.
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=True
            )
            return ir.passes.PassResult(model, modified=False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Run :class:`ShapeInferencePass` on ``model`` and return the resulting model.

    Functional shorthand for building a ``ShapeInferencePass`` with the given
    options, applying it to ``model``, and taking the ``model`` attribute of
    the pass result.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The model with shape inference applied.
    """
    shape_inference_pass = ShapeInferencePass(
        check_type=check_type,
        strict_mode=strict_mode,
        data_prop=data_prop,
    )
    pass_result = shape_inference_pass(model)
    return pass_result.model