# Copyright (c) ONNX Project Contributors## SPDX-License-Identifier: Apache-2.0"""Graph utilities for checking whether an ONNX proto message is legal."""from__future__importannotations__all__=["check_attribute","check_function","check_graph","check_model","check_node","check_sparse_tensor","check_tensor","check_value_info","DEFAULT_CONTEXT","LEXICAL_SCOPE_CONTEXT","ValidationError","C","MAXIMUM_PROTOBUF",]importosimportsysfromtypingimportTYPE_CHECKING,Any,Callable,TypeVarimportonnx.defsimportonnx.onnx_cpp2py_export.checkerasC# noqa: N812importonnx.shape_inferencefromonnximport(IR_VERSION,AttributeProto,FunctionProto,GraphProto,ModelProto,NodeProto,SparseTensorProto,TensorProto,ValueInfoProto,)ifTYPE_CHECKING:fromgoogle.protobuf.messageimportMessage# Limitation of single protobuf file is 2GiBMAXIMUM_PROTOBUF=2147483648# TODO: This thing where we reserialize the protobuf back into the# string, only to deserialize it at the call site, is really goofy.# Stop doing that.# NB: Please don't edit this context!DEFAULT_CONTEXT=C.CheckerContext()DEFAULT_CONTEXT.ir_version=IR_VERSION# TODO: Maybe ONNX-ML should also be defaulted?DEFAULT_CONTEXT.opset_imports={"":onnx.defs.onnx_opset_version()}LEXICAL_SCOPE_CONTEXT=C.LexicalScopeContext()FuncType=TypeVar("FuncType",bound=Callable[...,Any])def_ensure_proto_type(proto:Message,proto_type:type[Message])->None:ifnotisinstance(proto,proto_type):raiseTypeError(f"The proto message needs to be of type '{proto_type.__name__}'")
def check_model(
    model: ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
    check_custom_domain: bool = False,
) -> None:
    """Check the consistency of a model.

    An exception will be raised if the model's ir_version is not set
    properly or is higher than checker's ir_version, or if the model has
    duplicate keys in metadata_props.

    If IR version >= 3, the model must specify opset_import.
    If IR version < 3, the model cannot have any opset_import specified.

    Args:
        model: Model to check, given as a ``ModelProto``, its serialized
            ``bytes``, or a filesystem path (``str`` / ``os.PathLike``).
            A model whose serialized size exceeds 2GiB cannot be checked
            in memory and must be passed by path.
        full_check: If True, the function also runs shape inference check.
        skip_opset_compatibility_check: If True, the function skips the
            check for opset compatibility.
        check_custom_domain: If True, the function will check all domains.
            Otherwise only check built-in domains.

    Raises:
        ValueError: If an in-memory model serializes to more than
            ``MAXIMUM_PROTOBUF`` bytes.

    Exceptions raised by the C++ checker (``C.check_model`` /
    ``C.check_model_path``) propagate unchanged.
    """
    # A path is forwarded to the C++ checker as-is, without serializing in
    # Python, which is why oversized models have to be checked this way.
    if isinstance(model, (str, os.PathLike)):
        C.check_model_path(
            os.fspath(model),
            full_check,
            skip_opset_compatibility_check,
            check_custom_domain,
        )
        return

    protobuf_string = (
        model if isinstance(model, bytes) else model.SerializeToString()
    )
    # Measure the payload with len(): sys.getsizeof() also counts the bytes
    # object's interpreter header, which would wrongly reject models that
    # are just under the limit.
    if len(protobuf_string) > MAXIMUM_PROTOBUF:
        raise ValueError(
            "This protobuf of onnx model is too large (>2GiB). Call check_model with model path instead."
        )
    C.check_model(
        protobuf_string,
        full_check,
        skip_opset_compatibility_check,
        check_custom_domain,
    )