# Source code for onnx_ir.passes.common.onnx_checker
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Passes for debugging purposes."""
from __future__ import annotations
__all__ = [
"CheckerPass",
]
from typing import Literal
import onnx # noqa: TID251
import onnx_ir as ir
from onnx_ir.passes.common import _c_api_utils
# [docs]
class CheckerPass(ir.passes.PassBase):
    """Validate a model with ``onnx.checker.check_model`` without altering it.

    The pass is a read-only diagnostic: it hands the model to the ONNX checker
    and then reports the very same model back, flagged as unmodified.

    Attributes:
        full_check: Forwarded to the checker as ``full_check``.
        skip_opset_compatibility_check: Forwarded to the checker as
            ``skip_opset_compatibility_check``.
        check_custom_domain: Forwarded to the checker as
            ``check_custom_domain``.
    """

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        """Remember the options to use when the checker is invoked.

        Args:
            full_check: Value for the checker's ``full_check`` keyword.
            skip_opset_compatibility_check: Value for the checker's
                ``skip_opset_compatibility_check`` keyword.
            check_custom_domain: Value for the checker's
                ``check_custom_domain`` keyword.
        """
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def changes_input(self) -> Literal[False]:
        """The input model is left untouched by this pass."""
        return False

    @property
    def in_place(self) -> Literal[True]:
        """No new model is created; the pass works on the given model."""
        return True

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Check ``model`` with the ONNX checker and return it unchanged.

        Args:
            model: The IR model to validate.

        Returns:
            A pass result wrapping the same ``model`` with the modified flag
            set to ``False``.

        Note:
            This method does not catch exceptions itself, so whatever the
            checker (or the API bridge) raises reaches the caller.
        """
        # Gather the configured flags once so the callback stays a one-liner.
        checker_options = {
            "full_check": self.full_check,
            "skip_opset_compatibility_check": self.skip_opset_compatibility_check,
            "check_custom_domain": self.check_custom_domain,
        }

        def _check_proto(proto: onnx.ModelProto) -> None:
            """Run the ONNX checker on ``proto`` with the configured flags."""
            onnx.checker.check_model(proto, **checker_options)

        # The bridge helper receives the IR model and the callback; the
        # callback is annotated to take the protobuf form of the model.
        _c_api_utils.call_onnx_api(func=_check_proto, model=model)

        # Checking is read-only, so report the model as not modified.
        return ir.passes.PassResult(model, False)