Source code for onnx_ir.journaling._journaling

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Journaling system for ONNX IR operations."""

from __future__ import annotations

import weakref
from typing import Any

__all__ = ["Journal", "JournalEntry", "get_current_journal"]

import dataclasses
import datetime
import time
import traceback
from collections.abc import Callable, Sequence

from typing_extensions import Self

from onnx_ir.journaling import _wrappers

_current_journal: Journal | None = None


@dataclasses.dataclass(frozen=True)
class JournalEntry:
    """A single journal entry recording an operation on the IR.

    Attributes:
        timestamp: The time at which the operation was performed.
        operation: The name of the operation performed.
        class_: The class of the object on which the operation was performed.
        class_name: The name of the class of the object.
        ref: A weak reference to the object on which the operation was performed.
            To access the object, call ``ref()``. Note that ``ref`` may be ``None``,
            and ``ref()`` may return ``None`` if the object has been garbage-collected.
        obj: The referenced object, or None if it has been garbage-collected or not recorded.
            This is the same as calling ``entry.ref() if entry.ref is not None else None``.
        object_id: The unique identifier of the object (id()).
        stack_trace: The stack trace at the time of the operation.
        details: Additional details about the operation.
    """

    timestamp: float
    operation: str
    class_: type
    class_name: str
    ref: weakref.ref | None
    object_id: int
    stack_trace: list[traceback.FrameSummary]
    details: str | None

    @property
    def obj(self) -> Any | None:
        """Get the referenced object, or None if it has been garbage-collected or not recorded."""
        if self.ref is None:
            return None
        return self.ref()

[docs] def display(self) -> None: """Display the journal entry in a detailed multi-line format.""" # Header with timestamp timestamp = datetime.datetime.fromtimestamp(self.timestamp).strftime( "%Y-%m-%d %H:%M:%S.%f" ) print(f"\033[1m{'=' * 80}\033[0m") print(f"\033[1mTimestamp:\033[0m {timestamp}") print(f"\033[1mOperation:\033[0m {self.operation}") print(f"\033[1mClass:\033[0m {self.class_name} (id={self.object_id})") # Object representation if self.ref is None: object_repr = "<no ref>" elif (obj := self.ref()) is not None: object_repr = repr(obj) else: object_repr = "<deleted>" print("\033[1mObject:\033[0m") for line in object_repr.split("\n"): print(f" {line}") # Details if self.details: print("\033[1mDetails:\033[0m") for line in self.details.split("\n"): print(f" {line}") # Stack trace - find user code frame if self.stack_trace: user_frame = None for f in reversed(self.stack_trace): filename = f.filename.replace("\\", "/") if "onnx_ir" not in filename or "onnx_ir/passes" in filename: user_frame = f break print("\033[1mUser Code Location:\033[0m") if user_frame is not None: print( f" \033[90m{user_frame.filename}:{user_frame.lineno} in {user_frame.name}\033[0m" ) if user_frame.line: print(f" \033[90m>>> {user_frame.line}\033[0m") else: print(" \033[90m<unknown>\033[0m") print("\033[1mFull Stack Trace:\033[0m") for f in self.stack_trace: print(f" \033[90m{f.filename}:{f.lineno} in {f.name}\033[0m") if f.line: print(f" \033[90m{f.line}\033[0m") print(f"\033[1m{'=' * 80}\033[0m")
def get_current_journal() -> Journal | None: """Get the current journal, if any.""" return _current_journal def _get_stack_trace() -> list[traceback.FrameSummary]: return traceback.extract_stack()[:-3] class Journal: """A journaling system to record operations on the ONNX IR. It can be used as a context manager to automatically record operations within a block. Example:: from onnx_ir.journaling import Journal journal = Journal() with Journal() as journal: # Perform operations on the ONNX IR pass for entry in journal.entries: print(f"Operation: {entry.operation} on {entry.class_name}") You can also filter the entries by operation or class name using the `filter` method:: filtered_entries = [ entry for entry in journal.entries if entry.operation == "set_name" and entry.class_name == "Node" ] """ def __init__(self) -> None: self._entries: list[JournalEntry] = [] self._previous_journal: Journal | None = None self._hooks: list[Callable[[JournalEntry], None]] = [] self._original_methods: dict[str, Callable] = {} def __enter__(self) -> Self: global _current_journal self._previous_journal = _current_journal _current_journal = self self._original_methods = _wrappers.wrap_ir_classes(self) return self def __exit__(self, exc_type, exc_value, exc_tb) -> None: _wrappers.restore_ir_classes(self._original_methods) global _current_journal _current_journal = self._previous_journal @property def entries(self) -> Sequence[JournalEntry]: """Get all recorded journal entries.""" return self._entries
[docs] def record(self, obj: Any, operation: str, details: str | None = None) -> None: """Record a new journal entry.""" entry = JournalEntry( timestamp=time.time(), operation=operation, class_=obj.__class__, class_name=obj.__class__.__name__, ref=weakref.ref(obj) if obj is not None else None, object_id=id(obj), stack_trace=_get_stack_trace(), details=details, ) self._entries.append(entry) for hook in self._hooks: hook(entry)
[docs] def add_hook(self, hook: Callable[[JournalEntry], None]) -> None: """Add a hook that will be called whenever a new journal entry is recorded.""" self._hooks.append(hook)
[docs] def clear_hooks(self) -> None: """Clear all hooks.""" self._hooks.clear()
[docs] def display(self) -> None: """Display all journal entries.""" for entry in self._entries: details = f" [{entry.details}]" if entry.details else "" timestamp = datetime.datetime.fromtimestamp(entry.timestamp).strftime( "%Y-%m-%d %H:%M:%S.%f" ) if entry.stack_trace: # Find the first frame that is not from internal onnx_ir modules frame = None for f in reversed(entry.stack_trace): # Normalize path separators for cross-platform compatibility filename = f.filename.replace("\\", "/") if "onnx_ir" not in filename or "onnx_ir/passes" in filename: frame = f break if frame is not None: location = f"{frame.filename}:{frame.lineno} in {frame.name}" else: location = "<unknown>" else: location = "<unknown>" print() print(f"[{timestamp}] \033[90m{location}\033[0m") if entry.ref is None: object_repr = "<no ref>" elif (obj := entry.ref()) is not None: object_repr = repr(obj).replace("\n", "\\n") if len(object_repr) > 100: object_repr = object_repr[:95] + "[...]" else: object_repr = "<deleted>" details_text = details.replace("\n", "\\n") if len(details_text) > 100: details_text = details_text[:95] + "[...]" print( f"Class: {entry.class_name}(id={entry.object_id}). Operation: {entry.operation}. Object: {object_repr}." ) if details: print(f"\033[90mDetails: {details_text}\033[0m")