# Source code for onnx_ir.passes.common.topological_sort

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Pass for topologically sorting the graphs."""

from __future__ import annotations

# Names exported by ``from ... import *``: only the pass class is public.
__all__ = ["TopologicalSortPass"]


import onnx_ir as ir


class TopologicalSortPass(ir.passes.InPlacePass):
    """Topologically sort graphs and functions in a model.

    The main graph and every function in ``model.functions`` are sorted in
    place through their ``sort()`` methods. ``Graph.sort`` is documented to
    sort all subgraphs too (e.g. bodies of control-flow nodes). For that
    reason, the ``modified`` flag compares node order recursively, not only
    at the top level.
    """

    @staticmethod
    def _snapshot(graph_likes: list[ir.Graph | ir.Function]) -> list[ir.Node]:
        """Return all nodes of ``graph_likes``, including subgraph nodes, in their current order."""
        return [
            node
            for graph_like in graph_likes
            for node in ir.traversal.RecursiveGraphIterator(graph_like)
        ]

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Sort ``model``'s graph and functions in place and report whether anything moved.

        Args:
            model: The model whose main graph and functions are sorted.

        Returns:
            A ``PassResult`` wrapping the same ``model`` object. Its ``modified``
            flag is True if the position of any node (at any nesting depth) changed.
        """
        # The main graph goes first, so it is still sorted before any function.
        graph_likes = [model.graph, *model.functions.values()]

        original_nodes = self._snapshot(graph_likes)
        for graph_like in graph_likes:
            graph_like.sort()
        sorted_nodes = self._snapshot(graph_likes)

        # Sorting only permutes nodes within their own graph. Both snapshots
        # therefore hold the same nodes with equal length, and a node that moved
        # anywhere makes some position differ by identity.
        modified = any(
            before is not after for before, after in zip(original_nodes, sorted_nodes)
        )
        return ir.passes.PassResult(model=model, modified=modified)