# Source code for onnx_ir.passes.common.naming

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Name fix pass for ensuring unique names for all values and nodes."""

from __future__ import annotations

# Names exported by ``from <this module> import *``; the private helper
# ``_find_and_record_next_unique_name`` is intentionally not listed.
__all__ = [
    "NameFixPass",
    "NameGenerator",
    "SimpleNameGenerator",
]

import collections
import logging
from typing import Protocol

import onnx_ir as ir

# Module-level logger; the pass emits a DEBUG record for every name it assigns or changes.
logger = logging.getLogger(__name__)


class NameGenerator(Protocol):
    """Structural interface for objects that propose names for nodes and values.

    :class:`NameFixPass` calls these methods to obtain a *preferred* name and
    appends a numeric suffix itself when that name is already taken, so
    implementations do not have to return unique names.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Generate a preferred name for a node.

        Args:
            node: The node that needs a name (it may or may not already have one).

        Returns:
            The preferred name; it is used as the base when a suffix is needed.
        """
        ...

    def generate_value_name(self, value: ir.Value) -> str:
        """Generate a preferred name for a value.

        Args:
            value: The value that needs a name (it may or may not already have one).

        Returns:
            The preferred name; it is used as the base when a suffix is needed.
        """
        ...


class SimpleNameGenerator(NameGenerator):
    """Default name generator: keep an existing name, else use a generic one.

    :class:`NameFixPass` falls back to an instance of this class when no
    generator is supplied.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Return the node's current name, or ``"node"`` when it has none."""
        existing = node.name
        if existing:
            return existing
        return "node"

    def generate_value_name(self, value: ir.Value) -> str:
        """Return the value's current name, or ``"v"`` when it has none."""
        existing = value.name
        if existing:
            return existing
        return "v"


class NameFixPass(ir.passes.InPlacePass):
    """Pass for fixing names to ensure all values and nodes have unique names.

    This pass ensures that:

    1. Graph inputs and outputs have unique names (take precedence)
    2. All intermediate values have unique names (assign names to unnamed values)
    3. All values in subgraphs have unique names within their graph and parent graphs
    4. All nodes have unique names within their graph

    Value names are tracked per scope: a subgraph starts with every name
    already recorded in its parent scope. Node names are tracked separately
    for each graph. The main graph and each function are processed
    independently of one another.

    You can customize the name generation for nodes and values by passing any
    object that implements the :class:`NameGenerator` protocol. For example,
    you can use a custom naming scheme like this::

        class CustomNameGenerator:
            def generate_node_name(self, node: ir.Node) -> str:
                return f"custom_node_{node.op_type}"

            def generate_value_name(self, value: ir.Value) -> str:
                return f"custom_value_{value.type}"

        name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())

    .. versionadded:: 0.1.6
    """

    def __init__(
        self,
        name_generator: NameGenerator | None = None,
    ) -> None:
        """Initialize the NameFixPass with custom name generation functions.

        Args:
            name_generator: An object implementing :class:`NameGenerator` used
                to produce preferred names for nodes and values. If not
                provided, defaults to :class:`SimpleNameGenerator`, which uses
                the node's or value's existing name or a generic name like
                "node" or "v".
        """
        super().__init__()
        self._name_generator = name_generator or SimpleNameGenerator()

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Fix names in the main graph and in every function of ``model``.

        Args:
            model: The model to modify in place.

        Returns:
            A pass result wrapping the same model; ``modified`` is True when
            at least one node or value name was assigned or changed.
        """
        # Process the main graph
        modified = self._fix_graph_names(model.graph)

        # Process functions. The call is placed first so that it always runs,
        # even when an earlier graph has already been modified.
        for function in model.functions.values():
            modified = self._fix_graph_names(function) or modified

        return ir.passes.PassResult(model, modified=modified)

    def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
        """Fix names in a graph and return whether modifications were made."""
        modified = False

        # Values that have already been handled, so each one is processed once
        seen_values: set[ir.Value] = set()

        # The first set is a dummy placeholder so that there is always a [-1] scope for access
        # (even though we don't write to it)
        scoped_used_value_names: list[set[str]] = [set()]
        scoped_used_node_names: list[set[str]] = [set()]

        # Per-base-name counters used to build numeric suffixes; they are shared
        # by every scope visited during this call.
        value_counter = collections.Counter()
        node_counter = collections.Counter()

        def fix_value(value: ir.Value) -> None:
            """Process one value in the innermost scope and record any change."""
            nonlocal modified
            if self._process_value(
                value, scoped_used_value_names[-1], seen_values, value_counter
            ):
                modified = True

        def enter_graph(graph_like) -> None:
            """Callback for entering a subgraph."""
            # Initialize new scopes: value names inherit everything recorded in
            # the parent scope, node names start empty for each graph.
            scoped_used_value_names.append(set(scoped_used_value_names[-1]))
            scoped_used_node_names.append(set())

            # Step 1: Fix graph input names first (they have precedence)
            for input_value in graph_like.inputs:
                fix_value(input_value)

            # Step 2: Fix graph output names (they have precedence)
            for output_value in graph_like.outputs:
                fix_value(output_value)

            if isinstance(graph_like, ir.Graph):
                # For graphs, also fix initializers. Iterate over a snapshot:
                # renaming an initializer re-keys it in ``initializers`` (see
                # ``_process_value``), and mutating the mapping while iterating
                # its live view is unsafe.
                for initializer in tuple(graph_like.initializers.values()):
                    fix_value(initializer)

        def exit_graph(_) -> None:
            """Callback for exiting a subgraph."""
            # Pop the current scope
            scoped_used_value_names.pop()
            scoped_used_node_names.pop()

        # Step 3: Process all nodes and their values
        for node in ir.traversal.RecursiveGraphIterator(
            graph_like, enter_graph=enter_graph, exit_graph=exit_graph
        ):
            # Fix node name
            node_names = scoped_used_node_names[-1]
            if not node.name:
                node_changed = self._assign_node_name(node, node_names, node_counter)
            else:
                node_changed = self._fix_duplicate_node_name(node, node_names, node_counter)
            if node_changed:
                modified = True

            # Fix input value names (only if not already processed)
            for input_value in node.inputs:
                if input_value is not None:
                    fix_value(input_value)

            # Fix output value names (only if not already processed)
            for output_value in node.outputs:
                fix_value(output_value)

        return modified

    def _process_value(
        self,
        value: ir.Value,
        used_value_names: set[str],
        seen_values: set[ir.Value],
        value_counter: collections.Counter,
    ) -> bool:
        """Process a value only if it hasn't been processed before.

        Args:
            value: The value to name or rename.
            used_value_names: Names already taken in the current scope; updated in place.
            seen_values: Values already processed; ``value`` is added to it.
            value_counter: Per-base-name counters used to build suffixes.

        Returns:
            True if the value's name was assigned or changed.
        """
        if value in seen_values:
            return False

        if not value.name:
            modified = self._assign_value_name(value, used_value_names, value_counter)
        else:
            old_name = value.name
            modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
            if modified and value.is_initializer():
                # The initializers container is keyed by name, so re-register
                # the value under its new name. Only initializers need the
                # owning graph, hence the check lives inside this branch.
                graph = value.graph
                assert graph is not None
                graph.initializers.pop(old_name)
                # Add the initializer back with the new name
                graph.initializers.add(value)

        # Every processed value must end up with a name
        assert value.name is not None
        seen_values.add(value)
        return modified

    def _assign_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Assign a name to an unnamed value. Returns True if modified."""
        assert not value.name, (
            "value should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed value", value.name)
        return True

    def _assign_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Assign a name to an unnamed node. Returns True if modified."""
        assert not node.name, (
            "node should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed node", node.name)
        return True

    def _fix_duplicate_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Fix a value's name if it conflicts with existing names. Returns True if modified."""
        original_name = value.name

        assert original_name, (
            "value should have a name already if function is called correctly"
        )

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique
        base_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
        return True

    def _fix_duplicate_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Fix a node's name if it conflicts with existing names. Returns True if modified."""
        original_name = node.name

        assert original_name, "node should have a name already if function is called correctly"

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique
        base_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
        return True
def _find_and_record_next_unique_name(
    preferred_name: str, used_names: set[str], counter: collections.Counter
) -> str:
    """Return a name not present in ``used_names`` and record it there.

    ``preferred_name`` is returned unchanged when it is free. Otherwise the
    counter entry for ``preferred_name`` is advanced until
    ``f"{preferred_name}_{n}"`` is unused.

    Args:
        preferred_name: The desired name, also used as the suffix base.
        used_names: Names already taken; the chosen name is added in place.
        counter: Per-base-name counters; only touched when a suffix is needed.

    Returns:
        The unique name that was recorded.
    """
    candidate = preferred_name
    while candidate in used_names:
        suffix = counter[preferred_name] + 1
        counter[preferred_name] = suffix
        candidate = f"{preferred_name}_{suffix}"
    used_names.add(candidate)
    return candidate