Source code for onnx.shape_inference

# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

"""onnx shape inference. Shape inference is not guaranteed to be
complete.

"""

from __future__ import annotations

import os
from typing import Sequence

import onnx
import onnx.onnx_cpp2py_export.shape_inference as C  # noqa: N812
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto


[docs] def infer_shapes( model: ModelProto | bytes, check_type: bool = False, strict_mode: bool = False, data_prop: bool = False, ) -> ModelProto: """Apply shape inference to the provided ModelProto. Inferred shapes are added to the value_info field of the graph. If the inferred values conflict with values already provided in the graph, that means that the provided values are invalid (or there is a bug in shape inference), and the result is unspecified. Arguments: model: ModelProto. check_type: Checks the type-equality for input and output. strict_mode: Stricter shape inference, it will throw errors if any; Otherwise, simply stop if any error. data_prop: Enables data propagation for limited operators to perform shape computation. Returns: (ModelProto) model with inferred shape information """ if isinstance(model, (ModelProto, bytes)): model_str = model if isinstance(model, bytes) else model.SerializeToString() inferred_model_str = C.infer_shapes( model_str, check_type, strict_mode, data_prop ) return onnx.load_from_string(inferred_model_str) if isinstance(model, str): raise TypeError( "infer_shapes only accepts ModelProto or bytes," "you can use infer_shapes_path for the model path (String)." ) raise TypeError( f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}" )
[docs] def infer_shapes_path( model_path: str | os.PathLike, output_path: str | os.PathLike = "", check_type: bool = False, strict_mode: bool = False, data_prop: bool = False, ) -> None: """Take model path for shape_inference. This function is the same as :func:`infer_shape` but supports >2GB models. The function outputs the inferred model to the `output_path`. The original model path is used if not specified. """ if isinstance(model_path, ModelProto): raise TypeError( "infer_shapes_path only accepts model Path (String)," "you can use infer_shapes for the ModelProto." ) try: model_path = os.fspath(model_path) except TypeError as exp: raise TypeError( "infer_shapes_path only accepts model path as a string or PathLike, " f"incorrect model path type: {type(model_path)}" ) from exp try: output_path = os.fspath(output_path) except TypeError as exp: raise TypeError( "infer_shapes_path only accepts output path as a string or PathLike, " f"incorrect output path type: {type(output_path)}" ) from exp if output_path == "": output_path = model_path C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
def infer_node_outputs( schema: onnx.defs.OpSchema, node: onnx.NodeProto, input_types: dict[str, onnx.TypeProto], input_data: dict[str, onnx.TensorProto] | None = None, input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None, opset_imports: list[onnx.OperatorSetIdProto] | None = None, ir_version: int = onnx.IR_VERSION, ) -> dict[str, onnx.TypeProto]: if not schema.has_type_and_shape_inference_function: # type: ignore return {} if input_data is None: input_data = {} if input_sparse_data is None: input_sparse_data = {} if opset_imports is None: passed_opset_imports = {} else: passed_opset_imports = {opset.domain: opset.version for opset in opset_imports} # catch KeyError if node's input does not exist in input_types passed_input_types = { key: input_types[key].SerializeToString() for key in node.input } # input_types will also be used as outer_scope_value_types so do not filter by node's input here for key in input_types: if key not in passed_input_types: passed_input_types[key] = input_types[key].SerializeToString() passed_input_data = { key: input_data[key].SerializeToString() for key in node.input if key in input_data } passed_sparse_input_data = { key: input_sparse_data[key].SerializeToString() for key in node.input if key in input_sparse_data } outputs = schema._infer_node_outputs( node.SerializeToString(), passed_input_types, passed_input_data, passed_sparse_input_data, passed_opset_imports, ir_version, ) # type: ignore[call-arg] return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}
[docs] def infer_function_output_types( function: FunctionProto, input_types: Sequence[TypeProto], attributes: Sequence[AttributeProto], ) -> list[TypeProto]: """Apply type-and-shape-inference to given function body, with given input types and given input attribute values. """ result = C.infer_function_output_types( function.SerializeToString(), [x.SerializeToString() for x in input_types], [x.SerializeToString() for x in attributes], ) def to_type_proto(x) -> TypeProto: type_proto = onnx.TypeProto() type_proto.ParseFromString(x) return type_proto return [to_type_proto(x) for x in result]
InferenceError = C.InferenceError