Source code for onnx_ir.traversal

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Utilities for traversing the IR graph."""

from __future__ import annotations

# Public names exported by this module.
__all__ = [
    "RecursiveGraphIterator",
]

from collections.abc import Iterator, Reversible
from typing import Callable, Union

from typing_extensions import Self

from onnx_ir import _core, _enums

# Any of the container types whose nodes can be traversed; used as the type of
# the root graph and of the argument passed to the enter/exit callbacks.
GraphLike = Union[_core.Graph, _core.Function, _core.GraphView]


class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
    """Iterator over the nodes of a graph and the subgraphs nested in its nodes."""

    def __init__(
        self,
        graph_like: GraphLike,
        *,
        recursive: Callable[[_core.Node], bool] | None = None,
        reverse: bool = False,
        enter_graph: Callable[[GraphLike], None] | None = None,
        exit_graph: Callable[[GraphLike], None] | None = None,
    ):
        """Iterate over the nodes in the graph, recursively visiting subgraphs.

        This iterator allows for traversing the nodes of a graph and its subgraphs
        in a depth-first manner. It supports optional callbacks for entering and
        exiting graphs, as well as a callback `recursive` to determine whether to
        visit subgraphs contained within nodes.

        Each node is yielded before the nodes of the subgraphs it contains, in
        both forward and reverse mode.

        .. versionadded:: 0.1.6
            Added the `enter_graph` and `exit_graph` callbacks.

        Args:
            graph_like: The graph to traverse.
            recursive: A callback that determines whether to recursively visit
                the subgraphs contained in a node. If not provided, all nodes in
                subgraphs are visited.
            reverse: Whether to iterate in reverse order. The nodes of each graph
                and the graphs of a ``GRAPHS`` attribute are reversed; a node's
                attributes are still visited in their stored order.
            enter_graph: An optional callback that is called exactly once each
                time a graph (the root graph or a subgraph) is entered, before
                any of its nodes are yielded.
            exit_graph: An optional callback that is called exactly once each
                time a graph (the root graph or a subgraph) has been fully
                iterated. It is not called if iteration stops early.
        """
        self._graph = graph_like
        self._recursive = recursive
        self._reverse = reverse
        self._enter_graph = enter_graph
        self._exit_graph = exit_graph
        # The generator is lazy: no callback fires until the first `next()`.
        self._iterator = self._recursive_node_iter(graph_like)

    def __iter__(self) -> Self:
        """Restart the traversal from the root graph and return ``self``."""
        self._iterator = self._recursive_node_iter(self._graph)
        return self

    def __next__(self) -> _core.Node:
        """Return the next node of the traversal."""
        return next(self._iterator)

    def _recursive_node_iter(
        self, graph: _core.Graph | _core.Function | _core.GraphView
    ) -> Iterator[_core.Node]:
        """Yield the nodes of ``graph`` and, recursively, of its subgraphs.

        This is the single place where the ``enter_graph`` / ``exit_graph``
        callbacks are invoked, so every graph gets exactly one call of each.
        """
        iterable = reversed(graph) if self._reverse else graph

        if self._enter_graph is not None:
            self._enter_graph(graph)

        for node in iterable:  # type: ignore[union-attr]
            yield node
            # Without a `recursive` predicate every node's subgraphs are visited.
            if self._recursive is not None and not self._recursive(node):
                continue
            yield from self._iterate_subgraphs(node)

        if self._exit_graph is not None:
            self._exit_graph(graph)

    def _iterate_subgraphs(self, node: _core.Node) -> Iterator[_core.Node]:
        """Yield the nodes of every subgraph held in the attributes of ``node``."""
        for attr in node.attributes.values():
            if not isinstance(attr, _core.Attr):
                continue
            if attr.type == _enums.AttributeType.GRAPH:
                subgraphs = (attr.value,)
            elif attr.type == _enums.AttributeType.GRAPHS:
                subgraphs = reversed(attr.value) if self._reverse else attr.value
            else:
                continue
            for subgraph in subgraphs:
                # Recurse through `_recursive_node_iter`, which already fires the
                # enter/exit callbacks. Calling them here as well would notify
                # the caller twice for the same subgraph.
                yield from self._recursive_node_iter(subgraph)

    def __reversed__(self) -> Iterator[_core.Node]:
        """Return a new iterator over the same graph in the opposite direction."""
        return RecursiveGraphIterator(
            self._graph,
            recursive=self._recursive,
            reverse=not self._reverse,
            enter_graph=self._enter_graph,
            exit_graph=self._exit_graph,
        )