class IdentityEliminationPass(ir.passes.InPlacePass):
    """Pass for eliminating redundant Identity nodes.

    This pass removes Identity nodes according to the following rules:

    1. For any node of the form `y = Identity(x)`, where `y` is not an output
       of any graph, replace all uses of `y` with a use of `x`, and remove the
       node.
    2. If `y` is an output of a graph, and `x` is neither an input nor an
       output of any graph, we can still do the elimination, but the value
       `x` should be renamed to be `y`.
    3. If `y` is a graph-output and `x` is a graph-input, we cannot eliminate
       the node. It should be retained.
    4. If `y` is a graph-output and `x` is also a graph-output, we cannot
       eliminate the node either: renaming `x` to `y` would rename (and thus
       drop) the other graph output that `x` already provides.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Main entry point for the identity elimination pass.

        Visits every node of the main graph and of each model-local function,
        recursing into their subgraphs via ``RecursiveGraphIterator``.

        Args:
            model: The model to rewrite in place.

        Returns:
            A ``PassResult`` wrapping the same ``model`` object, with
            ``modified`` set to True if at least one Identity node was removed.
        """
        modified = False

        # The main graph is processed first, then every function body; each is
        # traversed recursively so nested subgraphs are covered as well.
        for graph_like in (model.graph, *model.functions.values()):
            for node in ir.traversal.RecursiveGraphIterator(graph_like):
                if self._try_eliminate_identity_node(node):
                    modified = True

        if modified:
            logger.info("Identity elimination pass modified the model")

        return ir.passes.PassResult(model, modified=modified)

    def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
        """Try to eliminate a single Identity node.

        Args:
            node: Any node of the model; nodes that are not a well-formed
                default-domain Identity are left untouched.

        Returns:
            True if ``node`` was removed from its graph (the model was
            modified), False otherwise.
        """
        if node.op_type != "Identity" or node.domain != "":
            return False

        if len(node.inputs) != 1 or len(node.outputs) != 1:
            # Invalid Identity node, skip
            return False

        input_value = node.inputs[0]
        output_value = node.outputs[0]

        if input_value is None:
            # Cannot eliminate if input is None
            return False

        # Get the graph that contains this node
        graph_like = node.graph
        assert graph_like is not None, "Node must be in a graph"

        output_is_graph_output = output_value.is_graph_output()

        if output_is_graph_output:
            # Case 3: x is a graph input; its name belongs to the graph's
            # interface and must not be changed to y, so keep the node.
            if input_value.is_graph_input():
                return False
            # Case 4: x is already a graph output (e.g. outputs `[x, y]`, or
            # `y1 = Identity(x); y2 = Identity(x)` after y1 was eliminated).
            # Renaming x to y would silently rename that other output, so
            # keep the node.
            if input_value.is_graph_output():
                return False

        # Cases 1 & 2: reroute every consumer of y to read x instead.
        ir.convenience.replace_all_uses_with(output_value, input_value)

        if output_is_graph_output:
            # Case 2: x takes over y's name and y's slot(s) in the graph
            # outputs, so the graph's external interface is unchanged.
            input_value.name = output_value.name
            # Every occurrence is replaced, in case y is listed more than once.
            for idx, graph_output in enumerate(graph_like.outputs):
                if graph_output is output_value:
                    graph_like.outputs[idx] = input_value

        # NOTE(review): safe=True presumably makes remove() verify that y has
        # no remaining uses and is no longer a graph output — see the onnx_ir
        # Graph.remove documentation.
        graph_like.remove(node, safe=True)
        logger.debug("Eliminated identity node: %s", node)
        return True