Source code for onnx_ir.passes.common.shape_inference
# Copyright (c) ONNX Project Contributors# SPDX-License-Identifier: Apache-2.0"""Shape inference pass using onnx.shape_inference."""from__future__importannotations__all__=["ShapeInferencePass","infer_shapes",]importloggingimportonnx# noqa: TID251importonnx_irasirfromonnx_ir.passes.commonimport_c_api_utilslogger=logging.getLogger(__name__)def_merge_func(model:ir.Model,inferred_proto:onnx.ModelProto)->bool:"""Merge the shape inferred model with the original model. Args: model: The original IR model. inferred_proto: The ONNX model with shapes and types inferred. Returns: A tuple containing the modified model and a boolean indicating whether the model was modified. """inferred_model=ir.serde.deserialize_model(inferred_proto)modified=Falsefororiginal_graph,inferred_graphinzip(model.graphs(),inferred_model.graphs()):original_values=ir.convenience.create_value_mapping(original_graph)inferred_values=ir.convenience.create_value_mapping(inferred_graph)forname,valueinoriginal_values.items():ifnameininferred_values:inferred_value=inferred_values[name]ifvalue.shape!=inferred_value.shapeandinferred_value.shapeisnotNone:value.shape=inferred_value.shapemodified=Trueifvalue.dtype!=inferred_value.dtypeandinferred_value.dtypeisnotNone:value.dtype=inferred_value.dtypemodified=Trueelse:logger.warning("Value %s not found in inferred graph %s",name,inferred_graph.name)returnmodified
class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph."""

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run ONNX shape inference and merge the results into ``model``.

        Args:
            model: The model to run shape inference on.

        Returns:
            A pass result wrapping the same ``model`` object. ``modified`` is
            True only if merging the inferred model changed a shape or dtype;
            it is False when inference raised.
        """

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            # Bind this pass's options so the helper only has to supply the proto.
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Deliberate best effort: a failed inference must not abort the
            # pass pipeline. The exception is passed as the %s argument so the
            # placeholder is actually filled; exc_info keeps the traceback.
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=e
            )
            return ir.passes.PassResult(model, modified=False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Run :class:`ShapeInferencePass` on a model and return the resulting model.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The ``model`` attribute of the pass result.
    """
    shape_inference = ShapeInferencePass(
        check_type=check_type,
        strict_mode=strict_mode,
        data_prop=data_prop,
    )
    result = shape_inference(model)
    return result.model