# Copyright (c) ONNX Project Contributors# SPDX-License-Identifier: Apache-2.0## This module implements some APIs described in# https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html# for the ONNX IR.# The classes {PassResult and PassManager} are derived from# https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_base.py#L12# and# https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_manager.py#L147# The original code is licensed under the PyTorch License https://github.com/pytorch/pytorch/blob/main/LICENSE"""Passes infrastructure for the IR."""from__future__importannotationsimportdataclassesimportloggingfromcollections.abcimportSequencefromtypingimportLiteral,final__all__=["PassBase","Sequential","InPlacePass","FunctionalPass","PassManager","PassResult",# Errors"InvariantError","PreconditionError","PostconditionError","PassError",]importabcimportonnx_irasirlogger=logging.getLogger(__name__)classInvariantError(Exception):"""Raised when an invariant is violated."""classPreconditionError(InvariantError):"""Raised when a precondition is violated."""classPostconditionError(InvariantError):"""Raised when a postcondition is violated."""classPassError(RuntimeError):"""Raised when an error occurs during a pass."""@dataclasses.dataclassclassPassResult:"""Result of a pass. Attributes: model: The transformed model. modified: Whether the resulting model is different from the input model. """model:ir.Modelmodified:boolclassPassBase(abc.ABC):"""Base class for all passes. 
``in_place`` and ``changes_input`` properties and what they mean: +------------+------------------+----------------------------+ | | changes_inputs | not changes_inputs | +------------+------------------+----------------------------+ | in_place | in place | Side-effect-only pass | +------------+------------------+----------------------------+ | not | destructive | functional | | in_place | | | +------------+------------------+----------------------------+ """@property@abc.abstractmethoddefin_place(self)->bool:"""Whether the pass modifies the model in place and returns it. If True, the pass will return the same model object that was passed in. If False, the pass will return a new model object. """raiseNotImplementedError@property@abc.abstractmethoddefchanges_input(self)->bool:"""Whether the pass modifies input model."""raiseNotImplementedError@propertydefdestructive(self)->bool:"""Whether the pass will destroy the input model when ``in_place=False``. A pass is destructive if it is not in place and it modifies the input model. """returnnotself.in_placeandself.changes_inputdef__call__(self,model_or_result:ir.Model|PassResult,/)->PassResult:ifisinstance(model_or_result,PassResult):model=model_or_result.modelelse:model=model_or_result# Check preconditionstry:self.requires(model)exceptPreconditionError:raiseexceptExceptionase:raisePreconditionError(f"Pre-condition for pass '{self.__class__.__name__}' failed")fromeresult=self.call(model)# Check postconditionstry:self.ensures(model)exceptPostconditionError:raiseexceptExceptionase:raisePostconditionError(f"Post-condition for pass '{self.__class__.__name__}' failed")fromeifnotisinstance(result,PassResult):raiseTypeError(f"The result of the pass '{self.__class__.__name__}' should be type PassResult. 
""Please create one with ir.passes.PassResult().")# Checks that the declared in-place property is respectedifself.in_placeandresult.modelisnotmodel:raisePassError(f"The pass '{self.__class__.__name__}' is declared in-place, ""but the model returned is *not* the same object as the input model. ""Pass developer: Pass should return the same model object or the in_place property should return False.")ifnotself.in_placeandresult.modelismodel:raisePassError(f"The pass '{self.__class__.__name__}' is declared not in-place, ""but the model returned *is* the same object as the input model. ""Pass developer: Pass should return a new model object or the in_place property should return True.")returnresult
[docs]@abc.abstractmethoddefcall(self,model:ir.Model)->PassResult:"""The main entry point for the pass."""...
[docs]defrequires(self,model:ir.Model)->None:"""Pre-conditions for the pass. This is optional to implement, will be called before call() if run by a pass manager. """delmodel# Unused
[docs]defensures(self,model:ir.Model)->None:"""Post-conditions for the pass. This is optional to implement, will be called after call() if run by a pass manager. """delmodel# Unused
class InPlacePass(PassBase):
    """A pass that modifies the input model in place and returns it.

    Both ``in_place`` and ``changes_input`` are fixed to ``True`` and cannot
    be overridden by subclasses.
    """

    @property
    @final
    def in_place(self) -> Literal[True]:
        """Always True: the same model object is returned."""
        return True

    @property
    @final
    def changes_input(self) -> Literal[True]:
        """Always True: the input model is mutated."""
        return True


class FunctionalPass(PassBase):
    """A pass that returns a new model but does not modify the input model.

    Both ``in_place`` and ``changes_input`` are fixed to ``False`` and cannot
    be overridden by subclasses.
    """

    @property
    @final
    def in_place(self) -> Literal[False]:
        """Always False: a new model object is returned."""
        return False

    @property
    @final
    def changes_input(self) -> Literal[False]:
        """Always False: the input model is left untouched."""
        return False


class Sequential(PassBase):
    """Run a sequence of passes in order.

    Each pass receives the model produced by the pass before it.

    Raises:
        ValueError: If constructed with no passes.
    """

    def __init__(self, *passes: PassBase):
        if len(passes) == 0:
            raise ValueError("Sequential must take at least one pass")
        self.passes = passes
        # The whole sequence keeps model identity only if every member does.
        self._in_place = all(member.in_place for member in passes)
        # Decided by the first pass alone. If it is functional, every later
        # pass works on its fresh copy and the caller's model is untouched.
        # If it is in place (even side-effect-only, conservatively) or
        # destructive, the caller's model may be changed by the sequence.
        first = self.passes[0]
        self._changes_input = first.changes_input or first.in_place

    @property
    def in_place(self) -> bool:
        """True when every pass in the sequence is in place."""
        return self._in_place

    @property
    def changes_input(self) -> bool:
        """Whether running the sequence may modify the caller's model."""
        return self._changes_input

    def call(self, model: ir.Model) -> PassResult:
        """Thread ``model`` through each pass, feeding each output to the next.

        Returns:
            A ``PassResult`` with the final model; ``modified`` is True if any
            pass reported a modification.

        Raises:
            PassError: Wrapping any exception raised while running a pass.
        """
        any_modified = False
        current = model
        for index, member in enumerate(self.passes):
            logger.debug("Running the %s-th pass '%s'", index, member)
            try:
                outcome = member(current)
            except Exception as err:
                already_run = list(map(str, self.passes[:index]))
                raise PassError(
                    f"An error occurred when running the '{member}' pass after the "
                    f"following passes: {already_run}"
                ) from err
            current = outcome.model
            any_modified = any_modified or outcome.modified
        return PassResult(current, any_modified)
class PassManager(Sequential):
    """Pass manager for the IR.

    The PassManager is a Pass that runs a sequence of passes on a model,
    repeating the whole sequence up to ``steps`` times.

    Attributes:
        passes: The passes to run.
        steps: The number of times to run the passes.
        early_stop: Whether to stop running the passes if the graph stops changing.
    """

    def __init__(
        self,
        passes: Sequence[PassBase],
        steps: int = 1,
        early_stop: bool = True,
    ):
        # TODO(justinchuby): Implement constraints
        super().__init__(*passes)
        self.steps = steps
        self.early_stop = early_stop

    def call(self, model: ir.Model) -> PassResult:
        """Run the set of passes `steps` number of times or until the graph stops changing.

        Returns:
            A ``PassResult`` with the final model; ``modified`` is True if any
            sweep reported a modification.

        Raises:
            PassError: Wrapping any exception raised during a sweep, tagged
                with the index of the failing step.
        """
        changed_at_all = False
        for step_index in range(self.steps):
            try:
                # One sweep == Sequential.call over every pass in order.
                sweep = super().call(model)
            except Exception as err:
                raise PassError(f"An error occurred at step {step_index}") from err
            model = sweep.model
            sweep_changed = sweep.modified
            changed_at_all = changed_at_all or sweep_changed
            # A sweep that changed nothing means further sweeps would be no-ops.
            if self.early_stop and not sweep_changed:
                logger.info(
                    "PassManager: No more graph changes detected after step %s",
                    step_index,
                )
                break
        return PassResult(model, changed_at_all)