Source code for onnx_ir._core

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""data structures for the intermediate representation."""

# NOTES for developers:
# NOTE: None of these classes will have a "to_onnx" or "from_protobuf" method because
# we cannot assume that the build toolchain has protoc installed, and we would like
# to keep this module protobuf-free. This way we separate the concerns of the IR
# from those of serialization/deserialization.
#
# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.

from __future__ import annotations

import abc
import contextlib
import dataclasses
import heapq
import math
import mmap
import os
import sys
import textwrap
import typing
from collections.abc import (
    Collection,
    Hashable,
    Iterable,
    Iterator,
    Mapping,
    MutableSequence,
    Sequence,
)
from collections.abc import (
    Set as AbstractSet,
)
from typing import (
    Any,
    Callable,
    Generic,
    NamedTuple,
    SupportsInt,
    Union,
)

import ml_dtypes
import numpy as np
from typing_extensions import TypeIs

import onnx_ir
from onnx_ir import (
    _display,
    _enums,
    _graph_containers,
    _linked_list,
    _metadata,
    _name_authority,
    _protocols,
    _type_casting,
)

if typing.TYPE_CHECKING:
    import numpy.typing as npt
    from typing_extensions import TypeGuard

TArrayCompatible = typing.TypeVar(
    "TArrayCompatible",
    bound=Union[_protocols.ArrayCompatible, _protocols.DLPackCompatible],
)

# System is little endian
_IS_LITTLE_ENDIAN = sys.byteorder == "little"
# Data types that are not supported by numpy
_NON_NUMPY_NATIVE_TYPES = frozenset(
    (
        _enums.DataType.BFLOAT16,
        _enums.DataType.FLOAT8E4M3FN,
        _enums.DataType.FLOAT8E4M3FNUZ,
        _enums.DataType.FLOAT8E5M2,
        _enums.DataType.FLOAT8E5M2FNUZ,
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
        _enums.DataType.FLOAT4E2M1,
    )
)


def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
    """Use this function to check if an object is compatible with numpy.

    Avoid isinstance checks with the ArrayCompatible protocol for performance reasons.
    """
    return hasattr(obj, "__array__")


def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
    """Use this function to check if an object is compatible with DLPack.

    Avoid isinstance checks with the DLPackCompatible protocol for performance reasons.
    """
    return hasattr(obj, "__dlpack__")


class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
    """Convenience Shared methods for classes implementing TensorProtocol."""

    __slots__ = (
        "_doc_string",
        "_metadata",
        "_metadata_props",
        "_name",
    )

    def __init__(
        self,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._name: str | None = name
        self._doc_string: str | None = doc_string

    def _printable_type_shape(self) -> str:
        """Return a string representation of the shape and data type."""
        return f"{self.dtype},{self.shape}"

    def _repr_base(self) -> str:
        """Base string for the repr method.

        Example: Tensor<FLOAT,[5,42]>
        """
        return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

    @property
    def name(self) -> str | None:
        """The name of the tensor."""
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:
        self._name = value

    @property
    def doc_string(self) -> str | None:
        """The documentation string."""
        return self._doc_string

    @doc_string.setter
    def doc_string(self, value: str | None) -> None:
        self._doc_string = value

    @property
    def size(self) -> int:
        """The number of elements in the tensor."""
        return math.prod(self.shape.numpy())  # type: ignore[attr-defined]

    @property
    def nbytes(self) -> int:
        """The number of bytes in the tensor."""
        # Use math.ceil because when dtype is INT4, the itemsize is 0.5
        return math.ceil(self.dtype.itemsize * self.size)

    @property
    def metadata_props(self) -> dict[str, str]:
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    def display(self, *, page: bool = False) -> None:
        rich = _display.require_rich()

        if rich is None:
            status_manager = contextlib.nullcontext()
        else:
            import rich.status  # type: ignore[import-not-found, no-redef]  # pylint: disable=import-outside-toplevel

            status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")

        from onnx_ir._thirdparty import (  # pylint: disable=import-outside-toplevel
            asciichartpy,
        )

        with status_manager:
            # Construct the text to display
            lines = []
            array = self.numpy().flatten()
            lines.append(repr(self))
            lines.append("")
            nan_values = np.isnan(array)
            nan_count = np.count_nonzero(nan_values)
            inf_count = np.count_nonzero(np.isinf(array))
            numbers = array[~nan_values]
            lines.append(
                f"Min: {np.min(numbers)}, Max: {np.max(numbers)}, "
                f"NaN count: {nan_count}, "
                f"Inf count: {inf_count}"
            )
            # Compute sparsity
            sparse_threshold = 1e-6
            # NOTE: count_nonzero() is faster than sum() for boolean arrays
            sparsity = np.count_nonzero(np.abs(array) < sparse_threshold) / array.size
            lines.append(f"Sparsity (abs<{sparse_threshold}): {sparsity:.2f}")

            # Compute histogram
            finite_numbers = array[np.isfinite(array)]
            lines.append("Histogram:")
            hist, bin_edges = np.histogram(finite_numbers, bins=80, density=False)
            lines.append(
                asciichartpy.plot(
                    hist, bin_edges=bin_edges, cfg={"height": 8, "format": "{:8.0f}"}
                )
            )

            text = "\n".join(lines)

        if rich is None:
            print(text)
        elif page:
            import rich.console  # type: ignore[import-not-found, no-redef]  # pylint: disable=import-outside-toplevel

            console = rich.console.Console()
            with console.pager():
                console.print(text)
        else:
            rich.print(text)


def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None:
    """Check if the numpy array dtype matches the IR data type.

    When the dtype is not one of the numpy native dtypes, the value needs to be:

    - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
    - ``uint8`` for uint4 or float4.
    - ``uint8`` for 8-bit data types.
    - ``uint16`` for bfloat16

    or corresponding dtypes from the ``ml_dtypes`` package.
    """
    if dtype in _NON_NUMPY_NATIVE_TYPES:
        if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
            raise TypeError(
                f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
            )
        if dtype.itemsize == 1 and array.dtype not in (
            np.uint8,
            ml_dtypes.float8_e4m3fnuz,
            ml_dtypes.float8_e4m3fn,
            ml_dtypes.float8_e5m2fnuz,
            ml_dtypes.float8_e5m2,
        ):
            raise TypeError(
                f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
            )
        if dtype == _enums.DataType.INT4:
            if array.dtype not in (np.int8, np.uint8, ml_dtypes.int4):
                raise TypeError(
                    f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int4 (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.UINT4:
            if array.dtype not in (np.uint8, ml_dtypes.uint4):
                raise TypeError(
                    f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.FLOAT4E2M1:
            if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
                raise TypeError(
                    f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
                )
        return

    try:
        dtype_numpy = _enums.DataType.from_numpy(array.dtype)
    except TypeError as e:
        raise TypeError(
            "Failed to convert the numpy dtype to an IR data type. "
            "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when "
            "creating a Tensor."
        ) from e

    if dtype_numpy != dtype:
        raise TypeError(
            f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}."
        )


def _maybe_view_np_array_with_ml_dtypes(
    array: np.ndarray, dtype: _enums.DataType
) -> np.ndarray:
    """Reinterpret the array when it is a bit representation of a dtype not supported by numpy.

    Args:
        array: The numpy array to reinterpret.
        dtype: The data type to reinterpret the array as.

    Returns:
        The array reinterpreted as the dtype.
    """
    if dtype == _enums.DataType.BFLOAT16:
        return array.view(ml_dtypes.bfloat16)
    if dtype == _enums.DataType.FLOAT8E4M3FN:
        return array.view(ml_dtypes.float8_e4m3fn)
    if dtype == _enums.DataType.FLOAT8E4M3FNUZ:
        return array.view(ml_dtypes.float8_e4m3fnuz)
    if dtype == _enums.DataType.FLOAT8E5M2:
        return array.view(ml_dtypes.float8_e5m2)
    if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
        return array.view(ml_dtypes.float8_e5m2fnuz)
    if dtype == _enums.DataType.INT4:
        return array.view(ml_dtypes.int4)
    if dtype == _enums.DataType.UINT4:
        return array.view(ml_dtypes.uint4)
    if dtype == _enums.DataType.FLOAT4E2M1:
        return array.view(ml_dtypes.float4_e2m1fn)
    return array
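

# Illustrative sketch (not part of the original module): how the view above
# reinterprets raw bit patterns. 0x3F80 is the bfloat16 bit pattern for 1.0
# and 0x0000 for 0.0.
def _demo_bfloat16_view() -> None:
    bits = np.array([0x3F80, 0x0000], dtype=np.uint16)  # raw bfloat16 bit patterns
    viewed = _maybe_view_np_array_with_ml_dtypes(bits, _enums.DataType.BFLOAT16)
    assert viewed.dtype == ml_dtypes.bfloat16
    assert float(viewed[0]) == 1.0  # same bytes, reinterpreted as bfloat16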


class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
    """An immutable concrete tensor.

    This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array
    compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object.
    The tensor is immutable and the data is not copied at initialization.

    To create a tensor from a numpy array::

        >>> import numpy as np
        >>> array = np.array([1, 2, 3])
        >>> tensor = Tensor(array)
        >>> # The tensor itself can be treated as a numpy array because it implements the __array__ method
        >>> np.allclose(tensor, array)
        True

    To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor
    to a byte string for serialization, call :meth:`tobytes`.

    It is recommended to check the size of the tensor before accessing the
    underlying data, because the access may be expensive and incur IO
    overhead.

    Subclass this class to efficiently handle different types of tensors from different frameworks.

    Attributes:
        name: The name of the tensor.
        shape: The shape of the tensor.
        dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
        doc_string: Documentation string.
        raw: The raw data behind this tensor. It can be anything.
        size: The number of elements in the tensor.
        nbytes: The number of bytes in the tensor.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    __slots__ = (
        "_dtype",
        "_raw",
        "_shape",
    )

    def __init__(
        self,
        value: TArrayCompatible,
        dtype: _enums.DataType | None = None,
        *,
        shape: Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
                When the dtype is not one of the numpy native dtypes, the value needs
                to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
                when the value is a numpy array; ``dtype`` must be specified in this case.
            dtype: The data type of the tensor. It can be None only when value is a numpy array.
                Users are responsible for making sure the dtype matches the value when value is not a numpy array.
            shape: The shape of the tensor. If None, the shape is obtained from the value.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
            TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array.
            ValueError: If the shape is not specified and the value does not have a shape attribute.
            ValueError: If the dtype is not specified and the value is not a numpy array.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        # NOTE: We should not do any copying here for performance reasons
        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
            raise TypeError(f"Expected an array compatible object, got {type(value)}")
        if shape is None:
            # Obtain the shape from the value
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        else:
            self._shape = shape
            self._shape.freeze()
        if dtype is None:
            if isinstance(value, np.ndarray):
                self._dtype = _enums.DataType.from_numpy(value.dtype)
            else:
                raise ValueError(
                    "The dtype must be specified when the value is not a numpy array."
                )
        else:
            if isinstance(value, np.ndarray):
                # Make sure the dtype matches the value
                _check_numpy_representation_type(value, dtype)
            # Users are responsible for making sure the dtype matches the value
            # when value is not a numpy array
            self._dtype = dtype

        # View the bfloat16, float8 and int4 types using ml_dtypes
        if isinstance(value, np.ndarray):
            value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype)  # type: ignore[assignment]

        self._raw = value

    def __array__(self, dtype: Any = None) -> np.ndarray:
        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
            return self._raw.__array__(dtype)
        assert _compatible_with_dlpack(self._raw), (
            f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
        )
        return np.from_dlpack(self._raw)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack__(stream=stream)
        return self.__array__().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack_device__()
        return self.__array__().__dlpack_device__()

    def __repr__(self) -> str:
        # Avoid multi-line repr
        tensor_lines = repr(self._raw).split("\n")
        tensor_text = " ".join(line.strip() for line in tensor_lines)
        return f"{self._repr_base()}({tensor_text}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> TArrayCompatible:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the data type is not supported by numpy, the dtypes from the
        ``ml_dtypes`` package are used. The values can be reinterpreted as bit
        representations using the ``.view()`` method.
        """
        if isinstance(self._raw, np.ndarray):
            return self._raw
        # We do not cache the value to save memory
        return self.__array__()

    def tobytes(self) -> bytes:
        """Returns the value as bytes encoded in little endian.

        Override this method for more efficient serialization when the raw
        value is not a numpy array.
        """
        # TODO(justinchuby): Support DLPack
        array = self.numpy()
        if self.dtype in {
            _enums.DataType.INT4,
            _enums.DataType.UINT4,
            _enums.DataType.FLOAT4E2M1,
        }:
            # Pack the array into int4
            array = _type_casting.pack_int4(array)
        else:
            assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
        if not _IS_LITTLE_ENDIAN:
            array = array.view(array.dtype.newbyteorder("<"))
        return array.tobytes()
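
# Illustrative sketch (not part of the original module): constructing a Tensor
# whose dtype is not numpy-native. The uint16 array holds bfloat16 bit
# patterns, so the IR dtype must be passed explicitly; tobytes() then emits
# the same raw bits (assuming a little-endian host).
def _demo_bfloat16_tensor() -> None:
    bits = np.array([0x3F80, 0x4000], dtype=np.uint16)  # bfloat16 bits for [1.0, 2.0]
    tensor = Tensor(bits, dtype=_enums.DataType.BFLOAT16, name="demo")
    assert tensor.numpy().dtype == ml_dtypes.bfloat16  # viewed via ml_dtypes
    assert tensor.tobytes() == bits.tobytes()  # serialization preserves the raw bits

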
class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """An immutable concrete tensor with its data stored on disk.

    This class uses memory mapping to avoid loading the tensor into memory,
    when the data type is supported by numpy. Otherwise, the tensor is loaded
    into memory lazily when accessed.

    Calling :attr:`shape` does not incur IO. Checking shape before loading
    the tensor is recommended if IO overhead and memory usage are a concern.

    To obtain an array, call :meth:`numpy`. To obtain the bytes,
    call :meth:`tobytes`.

    The :attr:`location` must be a relative path conforming to the ONNX
    specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
    to be the full path to the data file. Users should expect that the :attr:`path`
    always leads to the correct file. At initialization, paths are not checked.
    It is the user's responsibility to ensure the paths are valid and accessible.

    Attributes:
        location: The location of the data file. It is the path relative to the base directory.
        base_dir: The base directory for the external data. It is used to resolve relative paths.
            At serialization, only the :attr:`location` is serialized into the "location" field of the ``TensorProto``.
        path: The path to the data file. This is computed by joining :attr:`base_dir` and :attr:`location`.
        offset: The offset in bytes from the start of the file.
        length: The length of the data in bytes.
        dtype: The data type of the tensor.
        shape: The shape of the tensor.
        name: The name of the tensor. It must be specified.
        doc_string: The documentation string.
        metadata_props: The metadata properties.
    """

    __slots__ = (
        "_array",
        "_base_dir",
        "_dtype",
        "_length",
        "_location",
        "_offset",
        "_shape",
        "_valid",
        "raw",
    )

    def __init__(
        self,
        location: os.PathLike | str,
        offset: int | None,
        length: int | None,
        dtype: _enums.DataType,
        *,
        shape: Shape,
        name: str,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
        base_dir: os.PathLike | str = "",
    ) -> None:
        """Initialize an external tensor.

        Args:
            location: The location of the data file. It is the path relative to the base directory.
            offset: The offset in bytes from the start of the file.
            length: The length of the data in bytes.
            dtype: The data type of the tensor.
            shape: The shape of the tensor.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
            base_dir: The base directory for the external data. It is used to resolve relative paths.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        # NOTE: Do not verify the location by default. This is because the location field
        # in the tensor proto can be anything and we would like deserialization from
        # proto to IR to not fail.
        if onnx_ir.DEBUG:
            if os.path.isabs(location):
                raise ValueError(
                    "The location must be a relative path. Please specify base_dir as well."
                )
        self._location = location
        self._base_dir = base_dir
        self._offset: int | None = offset
        self._length: int | None = length
        self._dtype: _enums.DataType = dtype
        self.name: str = name  # mutable
        self._shape: Shape = shape
        self._shape.freeze()
        self.doc_string: str | None = doc_string  # mutable
        self._array: np.ndarray | None = None
        self.raw: mmap.mmap | None = None
        self._metadata_props = metadata_props
        self._metadata: _metadata.MetadataStore | None = None
        self._valid = True

    @property
    def base_dir(self) -> str | os.PathLike:
        # Mutable
        return self._base_dir

    @base_dir.setter
    def base_dir(self, value: str | os.PathLike) -> None:
        self._base_dir = value

    @property
    def location(self) -> str | os.PathLike:
        # Immutable
        return self._location

    @property
    def path(self) -> str:
        # Immutable, computed
        return os.path.join(self._base_dir, self._location)

    @property
    def offset(self) -> int | None:
        # Immutable
        return self._offset

    @property
    def length(self) -> int | None:
        # Immutable
        return self._length

    @property
    def dtype(self) -> _enums.DataType:
        # Immutable
        return self._dtype

    @property
    def shape(self) -> Shape:
        # Immutable
        return self._shape

    def _load(self):
        self._check_validity()
        assert self._array is None, "Bug: The array should be loaded only once."
        if self.size == 0:
            # When the size is 0, mmap is impossible and meaningless
            self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
            return
        # Map the whole file into the memory
        # TODO(justinchuby): Verify if this would exhaust the memory address space
        with open(self.path, "rb") as f:
            self.raw = mmap.mmap(
                f.fileno(),
                0,
                access=mmap.ACCESS_READ,
            )
        # Handle the byte order correctly by always using little endian
        dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
        if self.dtype in {
            _enums.DataType.INT4,
            _enums.DataType.UINT4,
            _enums.DataType.FLOAT4E2M1,
        }:
            # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
            dt = np.dtype(np.uint8).newbyteorder("<")
            count = self.size // 2 + self.size % 2
        else:
            count = self.size
        self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
        shape = self.shape.numpy()
        if self.dtype == _enums.DataType.INT4:
            # Unpack the int4 arrays
            self._array = _type_casting.unpack_int4(self._array, shape)
        elif self.dtype == _enums.DataType.UINT4:
            self._array = _type_casting.unpack_uint4(self._array, shape)
        elif self.dtype == _enums.DataType.FLOAT4E2M1:
            self._array = _type_casting.unpack_float4e2m1(self._array, shape)
        else:
            self._array = self._array.reshape(shape)

    def __array__(self, dtype: Any = None) -> np.ndarray:
        self._check_validity()
        if self._array is None:
            self._load()
        assert self._array is not None
        return self._array.__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        raise NotImplementedError(
            "ExternalTensor does not support DLPack because it uses memory mapping. "
            "Call numpy() to get a numpy array instead."
        )

    def __dlpack_device__(self) -> tuple[int, int]:
        raise NotImplementedError(
            "ExternalTensor does not support DLPack because it uses memory mapping. "
            "Call numpy() to get a numpy array instead."
        )

    def __repr__(self) -> str:
        return (
            f"{self._repr_base()}(location='{self.location}', name={self.name!r}, "
            f"offset={self.offset!r}, length={self.length!r}, base_dir={self.base_dir!r})"
        )

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        The data is memory-mapped from disk and does not take up additional
        physical memory space.
        """
        self._check_validity()
        if self._array is None:
            self._load()
        assert self._array is not None
        return self._array

    def tobytes(self) -> bytes:
        """Return the bytes of the tensor.

        This will load the tensor into memory.
        """
        self._check_validity()
        if self.raw is None:
            self._load()
        assert self.raw is not None
        offset = self._offset or 0
        length = self._length or self.nbytes
        return self.raw[offset : offset + length]

    def valid(self) -> bool:
        """Check if the tensor is valid.

        The external tensor is valid if it has not been invalidated.
        """
        return self._valid

    def _check_validity(self) -> None:
        if not self.valid():
            raise ValueError(
                f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
            )

    def invalidate(self) -> None:
        """Invalidate the tensor.

        The external tensor is invalidated when the data is known to be corrupted or deleted.
        """
        self._valid = False

    def release(self) -> None:
        """Delete all references to the memory buffer and close the memory-mapped file."""
        self._array = None
        if self.raw is not None:
            self.raw.close()
            self.raw = None
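
# Illustrative sketch (not part of the original module). The file name, sizes
# and base directory here are hypothetical; in practice these fields come from
# the external-data entries of a ``TensorProto``.
def _demo_external_tensor() -> None:
    ext = ExternalTensor(
        location="weights.bin",  # hypothetical relative path, per the ONNX spec
        offset=0,
        length=4 * 4,  # 4 float32 elements
        dtype=_enums.DataType.FLOAT,
        shape=Shape([4]),
        name="demo_weight",
        base_dir="/tmp/model",  # hypothetical base directory
    )
    # Shape, dtype and path are available without touching the disk; only
    # numpy()/tobytes() trigger the memory map.
    assert ext.path == os.path.join("/tmp/model", "weights.bin")
    assert ext.nbytes == 16

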
class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

    __slots__ = (
        "_raw",
        "_shape",
    )

    def __init__(
        self,
        value: Sequence[bytes] | npt.NDArray[np.bytes_],
        *,
        shape: Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes.
            shape: The shape of the tensor. If None, the shape is obtained from the value.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        if shape is None:
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        else:
            self._shape = shape
            self._shape.freeze()
        self._raw = value

    def __array__(self, dtype: Any = None) -> np.ndarray:
        if isinstance(self._raw, np.ndarray):
            return self._raw
        assert isinstance(self._raw, Sequence), (
            f"Bug: Expected a sequence, got {type(self._raw)}"
        )
        return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

    def __dlpack__(self, *, stream: Any = None) -> Any:
        del stream  # unused
        raise TypeError("StringTensor does not support DLPack")

    def __dlpack_device__(self) -> tuple[int, int]:
        raise TypeError("StringTensor does not support DLPack")

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return _enums.DataType.STRING

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> npt.NDArray[np.bytes_]:
        """Return the tensor as a numpy array."""
        return self.__array__()

    def tobytes(self) -> bytes:
        raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.")

    def string_data(self) -> Sequence[bytes]:
        """Return the string data of the tensor."""
        if isinstance(self._raw, np.ndarray):
            return self._raw.flatten().tolist()
        return self._raw
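
# Illustrative sketch (not part of the original module): string tensors store
# bytes to match ``TensorProto.string_data``, so text must be encoded first.
# The shape must be given explicitly because a plain list has no ``shape``.
def _demo_string_tensor() -> None:
    tensor = StringTensor([b"hello", b"world"], shape=Shape([2]))
    assert tensor.dtype == _enums.DataType.STRING
    assert tensor.string_data() == [b"hello", b"world"]

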
class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """A tensor that lazily evaluates a function to get the actual tensor.

    This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
    The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.

    Example::

        >>> import numpy as np
        >>> import onnx_ir as ir
        >>> weights = np.array([[1, 2, 3]])
        >>> def create_tensor():  # Delay applying transformations to the weights
        ...     weights_t = weights.transpose()
        ...     return ir.tensor(weights_t)
        >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([1, 3]))
        >>> print(lazy_tensor.numpy())
        [[1]
         [2]
         [3]]

    Attributes:
        func: The function that returns the actual tensor.
        dtype: The data type of the tensor.
        shape: The shape of the tensor.
        cache: Whether to cache the result of the function. If False,
            the function is called every time the tensor content is accessed.
            If True, the function is called only once and the result is cached in memory.
            Default is False.
        name: The name of the tensor.
        doc_string: The documentation string.
        metadata_props: The metadata properties.
    """

    __slots__ = (
        "_dtype",
        "_func",
        "_shape",
        "_tensor",
        "cache",
    )

    def __init__(
        self,
        func: Callable[[], _protocols.TensorProtocol],
        dtype: _enums.DataType,
        shape: Shape,
        *,
        cache: bool = False,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a lazy tensor.

        Args:
            func: The function that returns the actual tensor.
            dtype: The data type of the tensor.
            shape: The shape of the tensor.
            cache: Whether to cache the result of the function.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        self._func = func
        self._dtype = dtype
        self._shape = shape
        self._tensor: _protocols.TensorProtocol | None = None
        self.cache = cache

    def _evaluate(self) -> _protocols.TensorProtocol:
        """Evaluate the function to get the actual tensor."""
        if not self.cache:
            return self._func()
        # Cache the tensor
        if self._tensor is None:
            self._tensor = self._func()
        return self._tensor

    def __array__(self, dtype: Any = None) -> np.ndarray:
        return self._evaluate().__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        return self._evaluate().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        return self._evaluate().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"

    @property
    def raw(self) -> Callable[[], _protocols.TensorProtocol]:
        return self._func

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array."""
        return self._evaluate().numpy()

    def tobytes(self) -> bytes:
        """Return the bytes of the tensor."""
        return self._evaluate().tobytes()
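
# Illustrative sketch (not part of the original module): with ``cache=True``
# the factory function runs only once, no matter how often the data is read.
def _demo_lazy_tensor_cache() -> None:
    calls = []

    def make_tensor() -> Tensor:
        calls.append(1)  # record each evaluation
        return Tensor(np.array([1.0, 2.0], dtype=np.float32))

    lazy = LazyTensor(make_tensor, _enums.DataType.FLOAT, Shape([2]), cache=True)
    lazy.numpy()
    lazy.numpy()
    assert len(calls) == 1  # the second read reuses the cached tensor

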
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
    """Immutable symbolic dimension that can be shared across multiple shapes."""

    __slots__ = ("_value",)

    def __init__(self, value: str | None) -> None:
        """Initialize a symbolic dimension.

        Args:
            value: The value of the dimension. It should not be an int.
        """
        if isinstance(value, int):
            raise TypeError(
                "The value of a SymbolicDim cannot be an int. "
                "If you are creating a Shape, use int directly instead of SymbolicDim."
            )
        self._value = value

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SymbolicDim):
            return self.value == other
        return self.value == other.value

    def __hash__(self) -> int:
        return hash(self.value)

    @property
    def value(self) -> str | None:
        return self._value

    def __str__(self) -> str:
        return f"{self._value}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._value})"


def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
    """Return True if the value is int compatible."""
    if isinstance(value, int):
        return True
    if hasattr(value, "__int__"):
        # For performance reasons, we do not use isinstance(value, SupportsInt)
        return True
    return False


def _maybe_convert_to_symbolic_dim(
    dim: int | SupportsInt | SymbolicDim | str | None,
) -> SymbolicDim | int:
    """Convert the value to a SymbolicDim if it is not an int."""
    if dim is None or isinstance(dim, str):
        return SymbolicDim(dim)
    if _is_int_compatible(dim):
        return int(dim)
    if isinstance(dim, SymbolicDim):
        return dim
    raise TypeError(
        f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
    )


class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
    """The shape of a tensor, including its dimensions and optional denotations.

    The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
    symbolic dimensions.

    A shape can be compared to another shape or plain Python list.

    A shape can be frozen (made immutable). When the shape is frozen, it cannot be
    unfrozen, making it suitable to be shared across tensors or values.
    Call :meth:`freeze` to freeze the shape.

    To update the dimension of a frozen shape, call :meth:`copy` to create a
    new shape with the same dimensions that can be modified.

    Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.

    Example::

        >>> import onnx_ir as ir
        >>> shape = ir.Shape(["B", None, 3])
        >>> shape.rank()
        3
        >>> shape.is_static()
        False
        >>> shape.is_dynamic()
        True
        >>> shape.is_static(dim=2)
        True
        >>> shape[0] = 1
        >>> shape[1] = 2
        >>> shape.dims
        (1, 2, 3)
        >>> shape == [1, 2, 3]
        True
        >>> shape.frozen
        False
        >>> shape.freeze()
        >>> shape.frozen
        True

    Attributes:
        dims: A tuple of dimensions representing the shape.
            Each dimension can be an integer, None or a :class:`SymbolicDim`.
        frozen: Indicates whether the shape is immutable. When frozen, the shape
            cannot be modified or unfrozen.
    """

    __slots__ = ("_denotations", "_dims", "_frozen")

    def __init__(
        self,
        dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
        /,
        denotations: Iterable[str | None] | None = None,
        frozen: bool = False,
    ) -> None:
        """Initialize a shape.

        Args:
            dims: The dimensions of the shape. Each dimension can be an integer or a
                SymbolicDim or any Python object. When a ``dim`` is not an integer or a
                SymbolicDim, it is converted to a SymbolicDim.
            denotations: The denotations of the dimensions. If None, the denotations are not set.
                Standard denotation can optionally be used to denote tensor
                dimensions with standard semantic descriptions to ensure
                that operations are applied to the correct axis of a tensor.
                Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
                for pre-defined dimension denotations.
            frozen: If True, the shape is immutable and cannot be modified. This
                is useful when the shape is initialized by a Tensor or when the shape
                is shared across multiple tensors. The default is False.
        """
        self._dims: list[int | SymbolicDim] = [
            _maybe_convert_to_symbolic_dim(dim) for dim in dims
        ]
        self._denotations: list[str | None] = (
            list(denotations) if denotations is not None else [None] * len(self._dims)
        )
        if len(self._denotations) != len(self._dims):
            raise ValueError(
                "The number of denotations, when provided, must be equal to the number of dimensions."
            )
        self._frozen: bool = frozen

    @property
    def dims(self) -> tuple[int | SymbolicDim, ...]:
        """All dimensions in the shape.

        This property is read-only. Use __getitem__ and __setitem__ to modify the shape or create a new shape.
        """
        return tuple(self._dims)

    @property
    def frozen(self) -> bool:
        """Whether the shape is frozen.

        When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
        Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
        new shape with the same dimensions that can be modified.
        """
        return self._frozen

    def freeze(self) -> None:
        """Freeze the shape.

        When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
        """
        self._frozen = True

    def copy(self, frozen: bool = False):
        """Return a copy of the shape."""
        return Shape(self._dims, self._denotations, frozen=frozen)

    def rank(self) -> int:
        """The rank of the tensor this shape represents."""
        return len(self._dims)

    def numpy(self) -> tuple[int, ...]:
        if any(not isinstance(dim, int) for dim in self._dims):
            raise ValueError(f"Cannot convert the shape {self} to a tuple of ints")
        return tuple(dim for dim in self._dims)  # type: ignore

    def __len__(self) -> int:
        return len(self._dims)

    def __iter__(self) -> Iterator[int | SymbolicDim]:
        return iter(self._dims)

    @typing.overload
    def __getitem__(self, index: int) -> int | SymbolicDim: ...

    @typing.overload
    def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ...

    def __getitem__(self, index):
        return tuple(self._dims)[index]

    def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
        """Set the dimension at the index.

        Args:
            index: The index of the dimension.
            value: The value of the dimension.

        Raises:
            TypeError: If the shape is frozen and cannot be modified.
            TypeError: If the value is not an int or SymbolicDim.
        """
        if self._frozen:
            raise TypeError("The shape is frozen and cannot be modified.")
        self._dims[index] = _maybe_convert_to_symbolic_dim(value)

    def get_denotation(self, index: int) -> str | None:
        """Return the denotation of the dimension at the index.

        Args:
            index: The index of the dimension.

        Returns:
            The denotation of the dimension.
        """
        return self._denotations[index]

    def set_denotation(self, index: int, denotation: str | None) -> None:
        """Set the denotation of the dimension at the index.

        Args:
            index: The index of the dimension.
            denotation: The denotation of the dimension.
        """
        self._denotations[index] = denotation

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._dims!r})"

    def __str__(self) -> str:
        """Return a string representation of the shape.

        E.g. [n,1,3]
        """
        return f"[{','.join([str(dim) for dim in self._dims])}]"

    def __eq__(self, other: object) -> bool:
        """Return True if the shapes are equal.

        Two shapes are equal if all their dimensions are equal.
        """
        if isinstance(other, Shape):
            return self._dims == other._dims
        if not isinstance(other, Iterable):
            return False
        return self._dims == list(other)

    def __ne__(self, other: object) -> bool:
        return not self.__eq__(other)

    @typing.overload
    def is_static(self, dim: int) -> bool:  # noqa: D418
        """Return True if the dimension is static."""

    @typing.overload
    def is_static(self) -> bool:  # noqa: D418
        """Return True if all dimensions are static."""

    def is_static(self, dim=None) -> bool:
        """Return True if the dimension is static.

        If dim is None, return True if all dimensions are static.
        """
        if dim is None:
            return all(isinstance(dim, int) for dim in self._dims)
        return isinstance(self[dim], int)

    @typing.overload
    def is_dynamic(self, dim: int) -> bool:  # noqa: D418
        """Return True if the dimension is dynamic."""

    @typing.overload
    def is_dynamic(self) -> bool:  # noqa: D418
        """Return True if any dimension is dynamic."""

    def is_dynamic(self, dim=None) -> bool:
        if dim is None:
            return not self.is_static()
        return not self.is_static(dim)


def _quoted(string: str) -> str:
    """Return a quoted string.

    This function is used to quote value/node names in the IR for better readability.
    """
    return f'"{string}"'


class Usage(NamedTuple):
    """A usage of a value in a node.

    Attributes:
        node: The node that uses the value.
        idx: The input index of the value in the node.
    """

    node: Node
    idx: int


def _short_tensor_str_for_node(x: Value) -> str:
    if x.const_value is None:
        return ""
    if x.const_value.size <= 10:
        try:
            data = x.const_value.numpy().tolist()
        except Exception:  # pylint: disable=broad-except
            return "{...}"
        return f"{{{data}}}"
    return "{...}"


def _normalize_domain(domain: str) -> str:
    """Normalize 'ai.onnx' to ''."""
    return "" if domain == "ai.onnx" else domain


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
    """IR Node.

    If the ``graph`` is provided, the node will be added to the graph. Otherwise,
    the user is responsible for calling ``graph.append(node)`` (or other mutation methods
    in :class:`Graph`) to add the node to the graph.

    After the node is initialized, it will add itself as a user of the input values.

    The output values of the node are created during node initialization and are immutable.
    To change the output values, create a new node and replace each of the inputs in
    ``output.uses()`` with the new output values by calling :meth:`replace_input_with`
    on the using nodes of this node's outputs.

    .. note::
        When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
    """

    __slots__ = (
        "_attributes",
        "_domain",
        "_graph",
        "_inputs",
        "_metadata",
        "_metadata_props",
        "_name",
        "_op_type",
        "_outputs",
        "_overload",
        "_version",
        "doc_string",
    )

    def __init__(
        self,
        domain: str,
        op_type: str,
        inputs: Iterable[Value | None],
        attributes: Iterable[Attr] | Mapping[str, Attr] = (),
        *,
        overload: str = "",
        num_outputs: int | None = None,
        outputs: Sequence[Value] | None = None,
        version: int | None = None,
        graph: Graph | Function | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ):
        """Initialize a node and add it as a user of the input values.

        Args:
            domain: The domain of the operator. For onnx operators, this is an empty string.
                When it is `"ai.onnx"`, it is normalized to `""`.
            op_type: The name of the operator.
            inputs: The input values. When an input is ``None``, it is an empty input.
            attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
            overload: The overload name when the node is invoking a function.
            num_outputs: The number of outputs of the node. If not specified, the number is 1.
            outputs: The output values. If ``None``, the outputs are created during initialization.
            version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
            graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
                A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
                of the function is assigned to the node.
            name: The name of the node. If ``None``, the node is anonymous. The name may be
                set by a :class:`Graph` if ``graph`` is specified.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the attributes are not :class:`Attr`.
            ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
            ValueError: If an output value is ``None``, when outputs is specified.
            ValueError: If an output value has a producer set already, when outputs is specified.
        """
        self._name = name
        self._domain: str = _normalize_domain(domain)
        self._op_type: str = op_type
        # NOTE: Make inputs immutable with the assumption that they are not mutated
        # very often. This way all mutations can be tracked.
        # If necessary, we can cache the inputs and outputs as tuples.
        self._inputs: tuple[Value | None, ...] = tuple(inputs)
        # Values belong to their defining nodes. The values list is immutable
        self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
        if isinstance(attributes, Mapping):
            attributes = tuple(attributes.values())
        self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
            attributes
        )
        self._overload: str = overload
        # TODO(justinchuby): Potentially support a version range
        self._version: int | None = version
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        # _graph is set by graph.append
        self._graph: Graph | None = None
        # Add the node to the graph if graph is specified
        if graph is not None:
            graph.append(self)
        self.doc_string = doc_string

        # Add the node as a use of the inputs
        for i, input_value in enumerate(self._inputs):
            if input_value is not None:
                input_value._add_usage(self, i)  # pylint: disable=protected-access

    def _create_outputs(
        self, num_outputs: int | None, outputs: Sequence[Value] | None
    ) -> tuple[Value, ...]:
        """Check the parameters and create outputs for the node.

        Args:
            num_outputs: The number of outputs of the node.
            outputs: The output values of the node.

        Returns:
            The output values of the node.

        Raises:
            ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
            ValueError: If an output value is None.
            ValueError: If an output value has a producer set already.
        """
        # Check num_outputs and outputs are consistent
        if num_outputs is not None and outputs is not None and num_outputs != len(outputs):
            raise ValueError(
                "num_outputs must be the same as len(outputs) when num_outputs is specified. "
                f"num_outputs: {num_outputs}, outputs: {outputs}"
            )
        # 1. If outputs is specified (can be empty []), use the outputs
        if outputs is not None:
            # Check all output values are valid first
            for output in outputs:
                if output is None:
                    raise ValueError(f"Output value cannot be None. All outputs: {outputs}")
                if output.producer() is not None:
                    raise ValueError(
                        f"Supplied output value cannot have a producer when used for initializing a Node. "
                        f"Output: {output}. All outputs: {outputs}"
                    )
            result = []
            for i, output in enumerate(outputs):
                output._producer = self  # pylint: disable=protected-access
                output._index = i  # pylint: disable=protected-access
                result.append(output)
            return tuple(result)

        # 2. If num_outputs is specified, create num_outputs outputs
        if num_outputs is None:
            # Default to 1 output
            num_outputs = 1
        assert num_outputs is not None
        return tuple(Value(self, index=i) for i in range(num_outputs))

    def __str__(self) -> str:
        node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * (
            self._overload != ""
        )
        inputs_text = (
            "("
            + ", ".join(
                [
                    (
                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
                        if x is not None
                        else "None"
                    )
                    for x in self._inputs
                ]
            )
            + ")"
        )
        attributes_text = (
            (" {" + ", ".join([f"{k}={v}" for k, v in self._attributes.items()]) + "}")
            if self._attributes
            else ""
        )
        outputs_text = ", ".join(str(x) for x in self._outputs)
        return f"{outputs_text} ⬅️ {node_type_text}{inputs_text}{attributes_text}"

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(name={self._name!r}, domain={self._domain!r}, "
            f"op_type={self._op_type!r}, inputs={self._inputs!r}, attributes={self._attributes!r}, "
            f"overload={self._overload!r}, outputs={self._outputs!r}, "
            f"version={self._version!r}, doc_string={self.doc_string!r})"
        )

    @property
    def name(self) -> str | None:
        """Optional name of the node."""
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:
        self._name = value

    @property
    def domain(self) -> str:
        """The domain of the operator. For onnx operators, this is an empty string.

        .. note::
            When domain is `"ai.onnx"`, it is normalized to `""`.
        """
        return self._domain

    @domain.setter
    def domain(self, value: str) -> None:
        self._domain = _normalize_domain(value)

    @property
    def version(self) -> int | None:
        """Opset version of the operator called.

        If ``None``, the version is unspecified and will follow that of the graph.
        This property is special to ONNX IR to allow mixed opset usage in a graph
        for supporting more flexible graph transformations. It does not exist in the ONNX
        serialization (protobuf) spec.
        """
        return self._version

    @version.setter
    def version(self, value: int | None) -> None:
        self._version = value

    @property
    def op_type(self) -> str:
        """The name of the operator called."""
        return self._op_type

    @op_type.setter
    def op_type(self, value: str) -> None:
        self._op_type = value

    @property
    def overload(self) -> str:
        """The overload name when the node is invoking a function."""
        return self._overload

    @overload.setter
    def overload(self, value: str) -> None:
        self._overload = value

    @property
    def inputs(self) -> Sequence[Value | None]:
        """The input values of the node.

        The inputs are immutable. To change the inputs, create a new node and
        replace the inputs of the using nodes of this node's outputs by calling
        :meth:`replace_input_with` on the using nodes of this node's outputs.
        """
        return self._inputs

    @inputs.setter
    def inputs(self, _: Any) -> None:
        raise AttributeError(
            "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
        )

    def predecessors(self) -> Sequence[Node]:
        """Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
        # Use the ordered nature of a dictionary to deduplicate the nodes
        predecessors: dict[Node, None] = {}
        for value in self.inputs:
            if value is not None and (producer := value.producer()) is not None:
                predecessors[producer] = None
        return tuple(predecessors)

    def successors(self) -> Sequence[Node]:
        """Return the successor nodes of the node, deduplicated, in a deterministic order."""
        # Use the ordered nature of a dictionary to deduplicate the nodes
        successors: dict[Node, None] = {}
        for value in self.outputs:
            assert value is not None, "Bug: Output values are not expected to be None"
            for usage in value.uses():
                successors[usage.node] = None
        return tuple(successors)

    def replace_input_with(self, index: int, value: Value | None) -> None:
        """Replace an input with a new value."""
        if index < 0 or index >= len(self.inputs):
            raise ValueError(f"Index out of range: {index}")
        old_input = self.inputs[index]
        self._inputs = tuple(
            value if i == index else old_input
            for i, old_input in enumerate(self.inputs)
        )
        if old_input is not None:
            old_input._remove_usage(self, index)  # pylint: disable=protected-access
        if value is not None:
            value._add_usage(self, index)  # pylint: disable=protected-access

    def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
        """Insert a node before this node in the list of nodes in the graph.

        It is the same as calling ``graph.insert_before(self, nodes)``.

        Example::

            Before: previous_node -> self
                    previous_node' -> node -> next_node'
            After:  previous_node -> node -> self
                    previous_node' -> next_node'

        Args:
            nodes: A node or a sequence of nodes to put before this node.
        """
        if self._graph is None:
            raise ValueError("The node to prepend to does not belong to any graph.")
        self._graph.insert_before(self, nodes)

    def append(self, /, nodes: Node | Iterable[Node]) -> None:
        """Insert a node after this node in the list of nodes in the graph.

        It is the same as calling ``graph.insert_after(self, nodes)``.

        Example::

            Before: previous_node -> self
                    previous_node' -> node -> next_node'
            After:  previous_node -> self -> node
                    previous_node' -> next_node'

        Args:
            nodes: A node or a sequence of nodes to put after this node.
        """
        if self._graph is None:
            raise ValueError("The node to append to does not belong to any graph.")
        self._graph.insert_after(self, nodes)

    @property
    def outputs(self) -> Sequence[Value]:
        """The output values of the node.

        The outputs are immutable. To change the outputs, create a new node and
        replace the inputs of the using nodes of this node's outputs by calling
        :meth:`replace_input_with` on the using nodes of this node's outputs.
        """
        return self._outputs

    @outputs.setter
    def outputs(self, _: Sequence[Value]) -> None:
        raise AttributeError("outputs is immutable. Please create a new node instead.")

    @property
    def attributes(self) -> _graph_containers.Attributes:
        """The attributes of the node."""
        return self._attributes

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties of the node.

        The metadata properties are used to store additional information about the node.
        Unlike ``meta``, this property is serialized to the ONNX proto.
        """
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def graph(self) -> Graph | None:
        """The graph that the node belongs to.

        If the node is not added to any graph, this property is None.
        """
        return self._graph

    @graph.setter
    def graph(self, value: Graph | None) -> None:
        self._graph = value

    def op_identifier(self) -> _protocols.OperatorIdentifier:
        """Return the operator identifier of the node.

        The operator identifier is a tuple of the domain, op_type and overload.
        """
        return self.domain, self.op_type, self.overload

    def display(self, *, page: bool = False) -> None:
        """Pretty print the node.

        This method is used for debugging and visualization purposes.
        """
        # Add the node's name to the displayed text
        print(f"Node: {self.name!r}")
        if self.doc_string:
            print(f"Doc: {self.doc_string}")
        super().display(page=page)
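
# Illustrative sketch (not part of the original module): building a small
# chain of nodes by hand. Output values are created by the Node itself;
# inputs are rewired with replace_input_with().
def _demo_node_rewiring() -> None:
    x = Value(name="x", type=TensorType(_enums.DataType.FLOAT), shape=Shape([2]))
    relu = Node("", "Relu", inputs=[x], name="relu_0")
    sigmoid = Node("", "Sigmoid", inputs=[x], name="sigmoid_0")
    # Feed Relu's output into Sigmoid instead of the shared input value
    sigmoid.replace_input_with(0, relu.outputs[0])
    assert sigmoid.inputs[0] is relu.outputs[0]
    assert relu in sigmoid.predecessors()

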
class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
    """Tensor types that are non recursive types."""

    __slots__ = ("_dtype", "denotation")

    def __init__(self, dtype: _enums.DataType, *, denotation: str | None = None) -> None:
        self._dtype = dtype
        self.denotation = denotation

    @property
    def dtype(self) -> _enums.DataType:
        return self._dtype

    @dtype.setter
    def dtype(self, value: _enums.DataType) -> None:
        self._dtype = value

    @property
    def elem_type(self) -> _enums.DataType:
        """Return the element type of the tensor type."""
        return self.dtype

    def __hash__(self) -> int:
        return hash(repr(self))

    def __eq__(self, other: object) -> bool:
        if self.__class__ is not other.__class__:
            return False
        return self.dtype == other.dtype  # type: ignore[attr-defined]

    def __repr__(self) -> str:
        # Remove "Type" from name for display
        short_name = self.__class__.__name__[:-4]
        return f"{short_name}({self.dtype!r})"


class TensorType(_TensorTypeBase):
    """A type that represents a tensor."""

    def __str__(self) -> str:
        return f"{self.dtype}"


class SparseTensorType(_TensorTypeBase):
    """A type that represents a sparse tensor."""


class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
    """Base for recursive types like Optional and Sequence."""

    __slots__ = ("_elem_type", "denotation")

    def __init__(
        self, elem_type: _protocols.TypeProtocol, *, denotation: str | None = None
    ) -> None:
        self._elem_type = elem_type
        self.denotation = denotation

    @property
    def dtype(self) -> _enums.DataType:
        return self._elem_type.dtype

    @dtype.setter
    def dtype(self, value: _enums.DataType) -> None:
        self._elem_type.dtype = value

    @property
    def elem_type(self) -> _protocols.TypeProtocol:
        return self._elem_type

    def __hash__(self) -> int:
        return hash(repr(self))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _RecursiveTypeBase):
            return False
        if self.__class__ != other.__class__:
            return False
        # Recursively compare the type of the elements
        return self.elem_type == other.elem_type

    def __repr__(self) -> str:
        # Remove "Type" from name for display
        short_name = self.__class__.__name__[:-4]
        return f"{short_name}({self.elem_type!r})"


class SequenceType(_RecursiveTypeBase):
    """A type that represents a sequence of elements."""


class OptionalType(_RecursiveTypeBase):
    """A type that represents an optional element."""


class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
    """IR Value.

    A value is a named entity that can be used to represent an input or output of a graph,
    a function, or a node. The information it stores generalizes over ``ValueInfoProto``
    in the ONNX specification.

    A :class:`Value` is either not owned by any node, or owned by exactly one node.
    When the value is not owned, it must be an input of a graph or a function;
    ``producer`` and ``index`` are ``None``.

    When the value is owned by a node, it is an output of the node.
    The node that produces the value can be accessed with :meth:`producer`.
    The index of the output of the node that produces the value can be accessed with
    :meth:`index`.

    To find all the nodes that use this value as an input, call :meth:`uses`.

    To check if the value is an input, output or initializer of a graph,
    use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.

    Use :meth:`graph` to get the graph that owns the value.
    """

    __slots__ = (
        "_const_value",
        "_graph",
        "_index",
        "_is_graph_input",
        "_is_graph_output",
        "_is_initializer",
        "_metadata",
        "_metadata_props",
        "_name",
        "_producer",
        "_shape",
        "_type",
        "_uses",
        "doc_string",
    )

    def __init__(
        self,
        producer: Node | None = None,
        *,
        index: int | None = None,
        name: str | None = None,
        shape: Shape | None = None,
        type: _protocols.TypeProtocol | None = None,
        doc_string: str | None = None,
        const_value: _protocols.TensorProtocol | None = None,
    ) -> None:
        """Initialize a value.

        Args:
            producer: The node that produces the value.
                It can be ``None`` when the value is initialized before its producer.
            index: The index of the output of the defining node.
            name: The name of the value.
            shape: The shape of the value.
            type: The type of the value.
            doc_string: The documentation string.
            const_value: The constant tensor if the value is constant.
        """
        self._producer: Node | None = producer
        self._index: int | None = index
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = None
        self._name: str | None = name
        self._shape: Shape | None = shape
        self._type: _protocols.TypeProtocol | None = type
        # TODO(justinchuby): Handle initialization when a const value is provided
        # We can get shape and type information from the const value
        self._const_value = const_value
        # Use a collection of (Node, int) to store uses. This is needed
        # because a single use can use the same value multiple times.
        # Use a dictionary to preserve insertion order so that the visiting order is deterministic
        self._uses: dict[Usage, None] = {}
        self.doc_string = doc_string

        # The graph this value belongs to. It is set *only* when the value is added as
        # a graph input, output or initializer.
        # The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
        self._graph: Graph | None = None
        self._is_graph_input: bool = False
        self._is_graph_output: bool = False
        self._is_initializer: bool = False

    def __repr__(self) -> str:
        value_name = self.name if self.name else "anonymous:" + str(id(self))
        type_text = f", type={self.type!r}" if self.type is not None else ""
        shape_text = f", shape={self.shape!r}" if self.shape is not None else ""
        producer = self.producer()
        if producer is None:
            producer_text = ""
        elif producer.name is not None:
            producer_text = f", producer='{producer.name}'"
        else:
            producer_text = f", producer=anonymous_node:{id(producer)}"
        index_text = f", index={self.index()}" if self.index() is not None else ""
        const_value_text = self._constant_tensor_part()
        if const_value_text:
            const_value_text = f", const_value={const_value_text}"
        return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})"

    def __str__(self) -> str:
        value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
        shape_text = str(self.shape) if self.shape is not None else "?"
        type_text = str(self.type) if self.type is not None else "?"
        # Quote the name because in reality the names can have invalid characters
        # that make them hard to read
        return (
            f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
        )

    def _constant_tensor_part(self) -> str:
        """Display string for the constant tensor attached to str of Value."""
        if self.const_value is not None:
            # Only display when the const value is small
            if self.const_value.size <= 10:
                return f"{{{self.const_value}}}"
            else:
                return f"{{{self.const_value.__class__.__name__}(...)}}"
        return ""

    @property
    def graph(self) -> Graph | None:
        """Return the graph that defines this value.

        When the value is an input/output/initializer of a graph, the owning graph
        is that graph. When the value is an output of a node, the owning graph is the
        graph that the node belongs to. When the value is not owned by any graph,
        it returns ``None``.
        """
        if self._graph is not None:
            return self._graph
        if self._producer is not None:
            return self._producer.graph
        return None

    def _owned_by_graph(self) -> bool:
        """Return True if the value is owned by a graph."""
        result = self._is_graph_input or self._is_graph_output or self._is_initializer
        if result:
            assert self._graph is not None
        return result
    def producer(self) -> Node | None:
        """The node that produces this value.

        When producer is ``None``, the value does not belong to a node, and is
        typically a graph input or an initializer. You can use :meth:`graph` to
        find the graph that owns this value. Use :meth:`is_graph_input`,
        :meth:`is_graph_output` or :meth:`is_initializer` to check if the value
        is an input, output or initializer of a graph.
        """
        return self._producer
    def consumers(self) -> Sequence[Node]:
        """Return the nodes (deduplicated) that consume this value."""
        return tuple({usage.node: None for usage in self._uses})
    def index(self) -> int | None:
        """The index of the output of the defining node."""
        return self._index
    def uses(self) -> Collection[Usage]:
        """Return a set of uses of the value.

        The set contains tuples of ``(Node, index)`` where the index is the index
        of the input of the node. For example, if ``node.inputs[1] == value``,
        then the use is ``(node, 1)``.
        """
        # Create a tuple for the collection so that iteration will not be
        # affected when the usage changes during graph mutation.
        # This adds a small overhead but provides a better user experience than
        # requiring users to call tuple() themselves.
        return tuple(self._uses)
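    # A minimal usage sketch: each consuming input slot is a distinct ``Usage``,
    # so a node that reads the same value twice contributes two entries
    # (``Node`` is defined earlier in this module; names are hypothetical):
    #
    #   inp = Value(name="x")
    #   node = Node("", "Add", inputs=[inp, inp])
    #   assert inp.uses() == (Usage(node, 0), Usage(node, 1))
    #   assert inp.consumers() == (node,)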
    def _add_usage(self, use: Node, index: int) -> None:
        """Add a usage of this value.

        This is an internal method. It should only be called by the Node class.
        """
        self._uses[Usage(use, index)] = None

    def _remove_usage(self, use: Node, index: int) -> None:
        """Remove a node from the uses of this value.

        This is an internal method. It should only be called by the Node class.
        """
        self._uses.pop(Usage(use, index))

    @property
    def name(self) -> str | None:
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:
        if self._const_value is not None:
            self._const_value.name = value
        self._name = value

    @property
    def type(self) -> _protocols.TypeProtocol | None:
        """The type of the tensor.

        Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``,
        ``OptionalType``. To obtain the data type of the tensor, use ``type.dtype``
        or conveniently :attr:`dtype`.
        """
        return self._type

    @type.setter
    def type(self, value: _protocols.TypeProtocol | None) -> None:
        self._type = value

    @property
    def dtype(self) -> _enums.DataType | None:
        """The data type of the tensor."""
        if self._type is None:
            return None
        return self._type.dtype

    @dtype.setter
    def dtype(self, value: _enums.DataType) -> None:
        """Set the data type of the tensor.

        If the type is not set, it will be initialized to a new TensorType. To
        set the type as other types like ``SequenceType``, initialize the type
        then set :attr:`type` instead.
        """
        if self._type is None:
            self._type = TensorType(value)
        else:
            self._type.dtype = value

    @property
    def shape(self) -> Shape | None:
        return self._shape

    @shape.setter
    def shape(self, value: Shape | None) -> None:
        if value is None:
            self._shape = None
            return
        if isinstance(value, Shape):
            self._shape = value
            return
        raise TypeError(f"Expected value to be a Shape or None, got '{type(value)}'")

    @property
    def const_value(
        self,
    ) -> _protocols.TensorProtocol | None:
        """A concrete value.

        The value can be backed by different raw data types, such as numpy arrays.
        The only guarantee is that it conforms to the TensorProtocol.
        """
        return self._const_value

    @const_value.setter
    def const_value(
        self,
        value: _protocols.TensorProtocol | None,
    ) -> None:
        if onnx_ir.DEBUG:
            if value is not None and not isinstance(value, _protocols.TensorProtocol):
                raise TypeError(
                    f"Expected value to be a TensorProtocol or None, got '{type(value)}'"
                )
        self._const_value = value

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be
        serialized to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props
    def is_graph_input(self) -> bool:
        """Whether the value is an input of a graph."""
        return self._is_graph_input

    def is_graph_output(self) -> bool:
        """Whether the value is an output of a graph."""
        return self._is_graph_output

    def is_initializer(self) -> bool:
        """Whether the value is an initializer of a graph."""
        return self._is_initializer
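
# A minimal sketch of Value ownership: a standalone value has no producer, while
# node outputs point back at their defining node. ``Node`` is defined earlier in
# this module; the names below are hypothetical.
def _example_value_ownership() -> None:
    inp = Value(name="x", type=TensorType(_enums.DataType.FLOAT))
    node = Node("", "Relu", inputs=[inp])
    (out,) = node.outputs
    assert inp.producer() is None and inp.index() is None
    assert out.producer() is node and out.index() == 0
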
def Input(
    name: str | None = None,
    shape: Shape | None = None,
    type: _protocols.TypeProtocol | None = None,
    doc_string: str | None = None,
) -> Value:
    """Create an input of a Graph or a Function.

    This is equivalent to calling
    ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
    """
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Value(name=name, shape=shape, type=type, doc_string=doc_string)


def _check_node_safe_to_remove(
    node: Node, to_remove: AbstractSet[Node], graph_outputs: AbstractSet[Value]
) -> None:
    """Check if a node is safe to remove.

    1. It checks to make sure there are no users of the node that are not to be
       removed before removing it.
    2. It checks the node does not contribute to any graph outputs.

    This check is typically O(1) assuming the number of uses of the node is small.

    Args:
        node: The node to check.
        to_remove: A set of nodes that are to be removed.
            This set is used to check if the node is still being used by other
            nodes that are not to be removed.
        graph_outputs: A set of values that are outputs of the graph.

    Raises:
        ValueError: If the node contributes to a graph output.
        ValueError: If the node is still being used by other nodes not to be removed.
    """
    for output in node.outputs:
        if output in graph_outputs:
            raise ValueError(
                f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True."
            )
        uses_not_to_remove = [user for user, _ in output.uses() if user not in to_remove]
        if uses_not_to_remove:
            raise ValueError(
                f"Output value '{output!r}' is still being used by other nodes that are not to be "
                f"removed. Its users that are not being removed: {uses_not_to_remove!r}. "
                "Please make sure these nodes are no longer using the output value."
            )


class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
    """IR Graph.

    Graph represents a computation graph. In addition to the fields specified by
    the ONNX spec, it also contains a mapping of :attr:`opset_imports`. This
    allows different subgraphs to import different opsets. It is the
    responsibility of the deserializer to reconcile the different opsets.

    The nodes are not guaranteed to be topologically sorted, but the iteration
    order is deterministic across different runs. It is the responsibility of
    the user to maintain a topological order of the nodes.

    Note that there is no ``nodes`` attribute in the Graph. The Graph can be
    seen as a Sequence of nodes and should be used as such. For example, to
    obtain all nodes as a list, call ``list(graph)``.

    Attributes:
        name: The name of the graph.
        inputs: The input values of the graph.
        outputs: The output values of the graph.
        initializers: The initializers in the graph.
        doc_string: Documentation string.
        opset_imports: Opsets imported by the graph.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    __slots__ = (
        "_doc_string",
        "_initializers",
        "_inputs",
        "_metadata",
        "_metadata_props",
        "_name_authority",
        "_nodes",
        "_opset_imports",
        "_outputs",
        "name",
    )

    def __init__(
        self,
        inputs: Sequence[Value],
        outputs: Sequence[Value],
        *,
        nodes: Iterable[Node],
        initializers: Sequence[Value] = (),
        doc_string: str | None = None,
        opset_imports: dict[str, int] | None = None,
        name: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ):
        self.name = name

        # Private fields that are not to be accessed by any other classes
        self._inputs = _graph_containers.GraphInputs(self, inputs)
        self._outputs = _graph_containers.GraphOutputs(self, outputs)
        self._initializers = _graph_containers.GraphInitializers(
            self, {initializer.name: initializer for initializer in initializers}
        )
        self._doc_string = doc_string
        self._opset_imports = opset_imports or {}
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet()
        # Be sure to initialize the name authority before extending the nodes
        # because it is used to name the nodes and their outputs
        self._name_authority = _name_authority.NameAuthority()
        # TODO(justinchuby): Trigger again if inputs or initializers are modified.
        self._set_input_and_initializer_value_names_into_name_authority()
        # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
        self.extend(nodes)

    @property
    def inputs(self) -> MutableSequence[Value]:
        return self._inputs

    @property
    def outputs(self) -> MutableSequence[Value]:
        return self._outputs

    @property
    def initializers(self) -> _graph_containers.GraphInitializers:
        """The initializers of the graph as a ``MutableMapping[str, Value]``.

        The keys are the names of the initializers. The values are the
        :class:`Value` objects.

        This property additionally supports the ``add`` method, which takes a
        :class:`Value` and adds it to the initializers if it is not already present.

        .. note::
            When setting an initializer with ``graph.initializers[key] = value``,
            if the value does not have a name, it will be assigned ``key`` as its name.
        """
        return self._initializers
    def register_initializer(self, value: Value) -> None:
        """Register an initializer to the graph.

        This is a convenience method to register an initializer to the graph
        with checks.

        Args:
            value: The :class:`Value` to register as an initializer of the graph.
                It must have its ``.const_value`` set.

        Raises:
            ValueError: If a value of the same name that is not this value is already registered.
            ValueError: If the value does not have a name.
            ValueError: If the initializer is produced by a node.
            ValueError: If the value does not have its ``.const_value`` set.
        """
        if not value.name:
            raise ValueError(f"Initializer must have a name: {value!r}")
        if value.name in self._initializers:
            if self._initializers[value.name] is not value:
                raise ValueError(
                    f"Initializer '{value.name}' is already registered, but"
                    f" it is not the same object: existing={self._initializers[value.name]!r},"
                    f" new={value!r}"
                )
        if value.const_value is None:
            raise ValueError(
                f"Value '{value!r}' must have its const_value set to be an initializer."
            )
        self._initializers.add(value)
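    # A minimal sketch of registering an initializer (``Tensor`` is the tensor
    # class defined earlier in this module; names are hypothetical):
    #
    #   weight = Value(name="w", const_value=Tensor(np.zeros((2, 2), np.float32)))
    #   graph.register_initializer(weight)
    #   assert graph.initializers["w"] is weight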
    @property
    def doc_string(self) -> str | None:
        return self._doc_string

    @doc_string.setter
    def doc_string(self, value: str | None) -> None:
        self._doc_string = value

    @property
    def opset_imports(self) -> dict[str, int]:
        return self._opset_imports

    @typing.overload
    def __getitem__(self, index: int) -> Node: ...
    @typing.overload
    def __getitem__(self, index: slice) -> Sequence[Node]: ...

    def __getitem__(self, index):
        return self._nodes[index]

    def __len__(self) -> int:
        return len(self._nodes)

    def __iter__(self) -> Iterator[Node]:
        return iter(self._nodes)

    def __reversed__(self) -> Iterator[Node]:
        return reversed(self._nodes)

    def _set_input_and_initializer_value_names_into_name_authority(self):
        for value in self.inputs:
            self._name_authority.register_or_name_value(value)
        for value in self.initializers.values():
            self._name_authority.register_or_name_value(value)

    def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
        """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
        if node.graph is not None and node.graph is not self:
            raise ValueError(
                f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()."
            )
        # Give the node and its output values names if they don't have one
        self._name_authority.register_or_name_node(node)
        for value in node._outputs:  # pylint: disable=protected-access
            self._name_authority.register_or_name_value(value)
        node.graph = self
        return node
    def node(self, index_or_name: int | str, /) -> Node:
        """Get a node by index or name.

        This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1).

        .. note::
            If you need repeated random access, consider turning it into a list with
            ``list(graph)``. Or a dictionary for repeated access by name:
            ``{node.name: node for node in graph}``.

        When a name is provided and if there are multiple nodes with the same name,
        the first node with the name is returned.

        Args:
            index_or_name: The index or name of the node.

        Returns:
            The node if found.

        Raises:
            IndexError: If the index is out of range.
            ValueError: If the node with the given name is not found.
        """
        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
        if isinstance(index_or_name, int):
            return self[index_or_name]
        for node in self:
            if node.name == index_or_name:
                return node
        raise ValueError(f"Node with name '{index_or_name}' not found.")
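    # A minimal sketch: prefer materializing the nodes once when doing repeated
    # random access (names are hypothetical):
    #
    #   by_name = {node.name: node for node in graph}
    #   assert graph.node("my_node") is by_name["my_node"]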
    def num_nodes(self) -> int:
        """Get the number of nodes in the graph in O(1) time.

        Note that this method returns the number of nodes this graph directly contains.
        It does not count nodes in subgraphs.

        This is an alias for ``len(graph)``. Use this if you prefer a more descriptive
        name for readability.
        """
        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
        return len(self)
# Mutation methods
    def append(self, node: Node, /) -> None:
        """Append a node to the graph in O(1) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            node: The node to append.

        Raises:
            ValueError: If the node belongs to another graph.
        """
        self._set_node_graph_to_self_and_assign_names(node)
        self._nodes.append(node)
    def extend(self, nodes: Iterable[Node], /) -> None:
        """Extend the graph with the given nodes in O(#new_nodes) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            nodes: The nodes to extend the graph with.

        Raises:
            ValueError: If any node belongs to another graph.
        """
        nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in nodes]
        self._nodes.extend(nodes)
    def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
        """Remove nodes from the graph in O(#nodes to remove) time.

        If any errors are raised, the graph is left unmodified so that it is not
        in an inconsistent state.

        Args:
            nodes: The node(s) to remove.
            safe: If True, performs the following actions before removal:

                1. It checks to make sure there are no users of the node that are
                   not to be removed before removing it.
                2. It checks the node does not contribute to any graph outputs.
                3. It removes references to all inputs so it is no longer a user of other nodes.

        Raises:
            ValueError: If any node to remove does not belong to this graph.
            ValueError: (When ``safe=True``) If the node is still being used by
                other nodes not to be removed, or if it contributes to a graph output.
        """
        if not isinstance(nodes, Iterable):
            nodes_set: AbstractSet[Node] = {nodes}
        else:
            nodes_set = frozenset(nodes)
        graph_outputs = frozenset(self.outputs)
        for node in nodes_set:
            if node.graph is not self:
                raise ValueError(f"The node '{node!r}' does not belong to this graph.")
            if safe:
                # Check 1, 2
                _check_node_safe_to_remove(node, nodes_set, graph_outputs)
        for node in nodes_set:
            if safe:
                # 3. Detach from all inputs so that it is no longer a user of other nodes
                for i in range(len(node.inputs)):
                    node.replace_input_with(i, None)
            # Set attributes to remove the node from this graph
            node.graph = None
            self._nodes.remove(node)
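    # A minimal sketch of safe removal: with ``safe=True`` the call raises if
    # the node still has users outside the removal set or feeds a graph output
    # (names are hypothetical):
    #
    #   dead = graph.node("unused_cast")
    #   graph.remove(dead, safe=True)  # detaches inputs, then unlinks the node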
    def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
        """Insert new nodes after the given node in O(#new_nodes) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            node: The node to insert after.
            new_nodes: The new nodes to insert.

        Raises:
            ValueError: If any node belongs to another graph.
        """
        if isinstance(new_nodes, Node):
            new_nodes = (new_nodes,)
        new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes]
        self._nodes.insert_after(node, new_nodes)
    def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
        """Insert new nodes before the given node in O(#new_nodes) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            node: The node to insert before.
            new_nodes: The new nodes to insert.

        Raises:
            ValueError: If any node belongs to another graph.
        """
        if isinstance(new_nodes, Node):
            new_nodes = (new_nodes,)
        new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes]
        self._nodes.insert_before(node, new_nodes)
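    # A minimal sketch of splicing a node in front of an existing one
    # (``consumer`` and the attribute values are hypothetical):
    #
    #   cast = Node("", "Cast", inputs=[consumer.inputs[0]], attributes=[AttrInt64("to", 1)])
    #   graph.insert_before(consumer, cast)
    #   consumer.replace_input_with(0, cast.outputs[0])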
    def sort(self) -> None:
        """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time.

        This sort is stable. It preserves the original order as much as possible.

        Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort

        Raises:
            ValueError: If the graph contains a cycle, making topological sorting impossible.
        """
        # Obtain all nodes from the graph and its subgraphs for sorting
        nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self))
        # Store the sorted nodes of each subgraph
        sorted_nodes_by_graph: dict[Graph, list[Node]] = {
            graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
        }
        # TODO(justinchuby): Explain why we need to store direct predecessors and children and why
        # we only need to store the direct ones
        # The depth of a node is defined as the number of direct children it has
        node_depth: dict[Node, int] = dict.fromkeys(nodes, 0)
        # Direct predecessors of a node
        node_predecessors: dict[Node, list[Node]] = {node: [] for node in nodes}
        # Store the negative index of the nodes because heapq is a min heap and we
        # want to pop the node with largest index value first, effectively turning
        # it to a max heap
        neg_node_index: dict[Node, int] = {node: -i for i, node in enumerate(nodes)}

        def add_predecessor(child: Node, predecessor: Node | None) -> None:
            """Add a predecessor of a node, and increment the depth of the predecessor."""
            if predecessor is None:
                return
            node_predecessors[child].append(predecessor)
            node_depth[predecessor] += 1

        # 1. Build the direct predecessors of each node and the depth of each node
        # for sorting topologically using Kahn's algorithm.
        # Note that when a node contains graph attributes (aka. has subgraphs),
        # we consider all nodes in the subgraphs *predecessors* of this node. This
        # way we ensure the implicit dependencies of the subgraphs are captured
        # as predecessors of the node.
        for node in nodes:
            # All producers of input values are considered as direct predecessors.
            for input_value in node.inputs:
                if input_value is None:
                    continue
                predecessor_node = input_value.producer()
                add_predecessor(node, predecessor_node)
            # All nodes in attribute graphs are considered as direct predecessors.
            for attr in node.attributes.values():
                if not isinstance(attr, Attr):
                    continue
                # A nice thing about this algorithm is that we only need to record
                # direct predecessors. This continues to be true even with subgraphs.
                # When a node in a subgraph (a) contains its own subgraphs (b), the
                # nodes in subgraphs (b) are guaranteed to appear before the node
                # in (a).
                if attr.type == _enums.AttributeType.GRAPH:
                    for predecessor_node in attr.value:
                        add_predecessor(node, predecessor_node)
                elif attr.type == _enums.AttributeType.GRAPHS:
                    for attribute_graph in attr.value:
                        for predecessor_node in attribute_graph:
                            add_predecessor(node, predecessor_node)

        # 2. Priority Queue: Track nodes with zero direct children in a priority queue,
        # using NEGATIVE original index for ordering.
        # This ensures nodes appearing LATER in the original order are processed EARLIER.
        # We get REVERSED topological order of each subgraph.
        priority_queue: list[tuple[int, Node]] = [
            (neg_node_index[node], node) for node in nodes if node_depth[node] == 0
        ]
        heapq.heapify(priority_queue)

        # 3. Topological Sort:
        num_of_sorted_nodes = 0
        while priority_queue:
            # Pop the node with the most negative index and add it to the sorted nodes by subgraph.
            _, current_node = heapq.heappop(priority_queue)
            assert current_node.graph is not None
            sorted_nodes_by_graph[current_node.graph].append(current_node)
            num_of_sorted_nodes += 1
            # Decrement the depth of its predecessors. When a predecessor's depth
            # reaches zero (no unsorted direct children remain), push it into the queue.
            for predecessor_node in node_predecessors[current_node]:
                node_depth[predecessor_node] -= 1
                if node_depth[predecessor_node] == 0:
                    heapq.heappush(
                        priority_queue, (neg_node_index[predecessor_node], predecessor_node)
                    )

        # 4. Cycle Check: Ensure all nodes are processed. If not, raise a ValueError indicating a cycle.
        if num_of_sorted_nodes != len(nodes):
            raise ValueError("Graph contains a cycle, topological sort is not possible.")

        # 5. Reverse: Reverse the sorted nodes of each subgraph to get the topological order.
        for graph, sorted_nodes in sorted_nodes_by_graph.items():
            # The graph container ensures all the nodes are unique so we can safely extend
            graph.extend(reversed(sorted_nodes))
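    # A minimal sketch: nodes may be appended in any order and sorted afterwards;
    # the sort is stable and recurses into subgraphs (names are hypothetical):
    #
    #   graph.extend([consumer_node, producer_node])
    #   graph.sort()  # producer_node now precedes consumer_node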
    # End of mutation methods

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be
        serialized to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def __str__(self) -> str:
        return _graph_str(self)

    def __repr__(self) -> str:
        return _graph_repr(self)


def _graph_str(graph: Graph | GraphView) -> str:
    """Return a string representation of the graph."""
    # TODO(justinchuby): Show docstrings and metadata
    inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
    outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
    initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
    if initializers_text:
        initializers_text = (
            "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
        )
    signature = f"""\
graph(
    name={graph.name or "anonymous_graph:" + str(id(graph))},
    inputs=({textwrap.indent(inputs_text, " " * 8)}
    ),
    outputs=({textwrap.indent(outputs_text, " " * 8)}
    ),{textwrap.indent(initializers_text, " " * 4)}
)"""
    node_count = len(graph)
    number_width = len(str(node_count))
    node_lines = []
    for i, node in enumerate(graph):
        node_name = node.name if node.name else f":anonymous_node:{id(node)}"
        node_text = f"# {node_name}\n{node}"
        indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
        # Remove the leading spaces
        indented_node_text = indented_node_text.strip()
        node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
    returns = ", ".join(str(x) for x in graph.outputs)
    body = (
        "{\n"
        + textwrap.indent("\n".join(node_lines), " " * 4)
        + textwrap.indent(f"\nreturn {returns}", " " * 4)
        + "\n}"
    )
    return f"{signature} {body}"


def _graph_repr(graph: Graph | GraphView) -> str:
    """Return a repr string of the graph."""
    inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
    outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
    initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
    if initializers_text:
        initializers_text = (
            "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
        )
    return f"""\
{graph.__class__.__name__}(
    name={graph.name or "anonymous_graph:" + str(id(graph))!r},
    inputs=({textwrap.indent(inputs_text, " " * 8)}
    ),
    outputs=({textwrap.indent(outputs_text, " " * 8)}
    ),{textwrap.indent(initializers_text, " " * 4)}
    len()={len(graph)}
)"""


class GraphView(Sequence[Node], _display.PrettyPrintable):
    """A read-only view on a graph.

    The GraphView is useful for analysis of a subgraph. It can be initialized
    with a subset of nodes from a :class:`Graph`. Creating a GraphView does not
    change the ownership of the nodes, so it is possible to create multiple
    GraphViews that contain the same nodes. If the underlying nodes / connections
    are mutated, the mutation will be reflected in all views as well.

    The graph view can be serialized to ONNX::

        graph_proto = ir.serde.serialize_graph(graph_view)

    It can also be used to create a model::

        model = ir.Model(graph_view, ir_version=8)
        model_proto = ir.serde.serialize_model(model)

    The model created with a GraphView will have a fixed topology, and its graph
    will remain read-only as a GraphView. No copying will be done during the
    initialization process.

    Attributes:
        name: The name of the graph.
        inputs: The input values of the graph.
        outputs: The output values of the graph.
        initializers: The initializers in the graph.
        doc_string: Documentation string.
        opset_imports: Opsets imported by the graph.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    __slots__ = (
        "_metadata",
        "_metadata_props",
        "doc_string",
        "initializers",
        "inputs",
        "name",
        "nodes",
        "opset_imports",
        "outputs",
    )

    def __init__(
        self,
        inputs: Sequence[Value],
        outputs: Sequence[Value],
        *,
        nodes: Iterable[Node],
        initializers: Sequence[Value] = (),
        doc_string: str | None = None,
        opset_imports: dict[str, int] | None = None,
        name: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ):
        self.name = name
        self.inputs = tuple(inputs)
        self.outputs = tuple(outputs)
        self.initializers = {initializer.name: initializer for initializer in initializers}
        self.doc_string = doc_string
        self.opset_imports = opset_imports or {}
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._nodes: tuple[Node, ...] = tuple(nodes)

    @typing.overload
    def __getitem__(self, index: int) -> Node: ...
    @typing.overload
    def __getitem__(self, index: slice) -> Sequence[Node]: ...

    def __getitem__(self, index):
        return self._nodes[index]

    def __len__(self) -> int:
        return len(self._nodes)

    def __iter__(self) -> Iterator[Node]:
        return iter(self._nodes)

    def __reversed__(self) -> Iterator[Node]:
        return reversed(self._nodes)

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be
        serialized to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def __str__(self) -> str:
        return _graph_str(self)

    def __repr__(self) -> str:
        return _graph_repr(self)


class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
    """IR Model.

    A model is a container for a graph and metadata.

    Attributes:
        graph: The graph of the model.
        ir_version: The version of the IR.
        producer_name: The name of the producer.
        producer_version: The version of the producer.
        domain: The domain of the model.
        model_version: The version of the model.
        doc_string: Documentation string.
        functions: The functions defined in the model.
        metadata_props: Metadata.
    """

    __slots__ = (
        "_functions",
        "_metadata",
        "_metadata_props",
        "doc_string",
        "domain",
        "graph",
        "ir_version",
        "model_version",
        "producer_name",
        "producer_version",
    )

    def __init__(
        self,
        graph: Graph,
        *,
        ir_version: int,
        producer_name: str | None = None,
        producer_version: str | None = None,
        domain: str | None = None,
        model_version: int | None = None,
        doc_string: str | None = None,
        functions: Sequence[Function] = (),
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        self.graph: Graph = graph
        self.ir_version = ir_version
        self.producer_name = producer_name
        self.producer_version = producer_version
        self.domain = domain
        self.model_version = model_version
        self.doc_string = doc_string
        self._functions = {func.identifier(): func for func in functions}
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props

    @property
    def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
        return self._functions

    @property
    def opset_imports(self) -> dict[str, int]:
        return self.graph.opset_imports

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be
        serialized to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def __str__(self) -> str:
        # TODO(justinchuby): Show docstrings and metadata
        signature = f"""\
<
    ir_version={self.ir_version!r},
    opset_imports={self.opset_imports!r},
    producer_name={self.producer_name!r},
    producer_version={self.producer_version!r},
    domain={self.domain!r},
    model_version={self.model_version!r},
>"""
        graph_text = str(self.graph)
        functions_text = "\n\n".join(str(func) for func in self.functions.values())
        return f"{signature}\n{graph_text}" + f"\n\n{functions_text}"

    def __repr__(self) -> str:
        return f"""\
Model(
    ir_version={self.ir_version!r},
    opset_imports={self.opset_imports!r},
    producer_name={self.producer_name!r},
    producer_version={self.producer_version!r},
    domain={self.domain!r},
    model_version={self.model_version!r},
    functions={self.functions!r},
    graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
)"""
    def graphs(self) -> Iterable[Graph]:
        """Get all graphs and subgraphs in the model.

        This is a convenience method to traverse the model. Consider using
        `onnx_ir.traversal.RecursiveGraphIterator` for more advanced traversals on nodes.
        """
        # NOTE(justinchuby): Given
        # (1) how useful the method is
        # (2) I couldn't find an appropriate name for it in `traversal.py`
        # (3) Users familiar with onnxruntime optimization tools expect this method
        # I created this method as a core method instead of an iterator in
        # `traversal.py`.
        seen_graphs: set[Graph] = set()
        for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
            if node.graph is not None and node.graph not in seen_graphs:
                seen_graphs.add(node.graph)
                yield node.graph
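
# A minimal sketch showing how ``Model.graphs()`` can drive a whole-model
# analysis pass. The helper name below is hypothetical.
def _example_count_nodes_per_graph(model: Model) -> dict[str, int]:
    """Count the nodes directly contained in each graph of ``model``."""
    counts: dict[str, int] = {}
    for graph in model.graphs():
        counts[graph.name or f"anonymous_graph:{id(graph)}"] = len(graph)
    return counts
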
class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
    """IR functions.

    Like a graph, a function can have nodes that are not topologically sorted. It is
    the responsibility of the user to maintain a topological order of the nodes.

    Note that there is no ``nodes`` attribute in the Function. The Function can be
    seen as a Sequence of nodes and should be used as such. For example, to obtain
    all nodes as a list, call ``list(function)``.

    Attributes:
        name: The function name.
        domain: The domain this function is defined in.
        overload: The overload name when the function is overloaded.
        inputs: The input values of the function.
        attributes: The attributes this function defines.
        outputs: The output values of the function.
        opset_imports: Opsets imported by the function.
        doc_string: Documentation string.
        meta: Metadata store for graph transform passes.
        metadata_props: Metadata that will be serialized to the ONNX file.
    """

    __slots__ = (
        "_attributes",
        "_domain",
        "_graph",
        "_name",
        "_overload",
    )

    def __init__(
        self,
        domain: str,
        name: str,
        overload: str = "",
        *,
        # Ensure the inputs and outputs of the function belong to a graph
        # and do not come from an outer scope
        graph: Graph,
        attributes: Iterable[Attr] | Mapping[str, Attr],
    ) -> None:
        self._domain = domain
        self._name = name
        self._overload = overload
        self._graph = graph
        if isinstance(attributes, Mapping):
            attributes = tuple(attributes.values())
        self._attributes = _graph_containers.Attributes(attributes)
    def identifier(self) -> _protocols.OperatorIdentifier:
        return self.domain, self.name, self.overload
    @property
    def name(self) -> str:
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        self._name = value

    @property
    def domain(self) -> str:
        return self._domain

    @domain.setter
    def domain(self, value: str) -> None:
        self._domain = _normalize_domain(value)

    @property
    def overload(self) -> str:
        return self._overload

    @overload.setter
    def overload(self, value: str) -> None:
        self._overload = value

    @property
    def inputs(self) -> MutableSequence[Value]:
        return self._graph.inputs

    @property
    def outputs(self) -> MutableSequence[Value]:
        return self._graph.outputs

    @property
    def attributes(self) -> _graph_containers.Attributes:
        return self._attributes

    @typing.overload
    def __getitem__(self, index: int) -> Node: ...
    @typing.overload
    def __getitem__(self, index: slice) -> Sequence[Node]: ...

    def __getitem__(self, index):
        return self._graph.__getitem__(index)

    def __len__(self) -> int:
        return self._graph.__len__()

    def __iter__(self) -> Iterator[Node]:
        return self._graph.__iter__()

    def __reversed__(self) -> Iterator[Node]:
        return self._graph.__reversed__()

    @property
    def doc_string(self) -> str | None:
        return self._graph.doc_string

    @doc_string.setter
    def doc_string(self, value: str | None) -> None:
        self._graph.doc_string = value

    @property
    def opset_imports(self) -> dict[str, int]:
        return self._graph.opset_imports

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be
        serialized to the ONNX proto.
        """
        return self._graph.meta

    @property
    def metadata_props(self) -> dict[str, str]:
        return self._graph.metadata_props

    # Mutation methods
    def append(self, node: Node, /) -> None:
        """Append a node to the function in O(1) time."""
        self._graph.append(node)
    def extend(self, nodes: Iterable[Node], /) -> None:
        """Extend the function with the given nodes in O(#new_nodes) time."""
        self._graph.extend(nodes)
    def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
        """Remove nodes from the function in O(#nodes to remove) time.

        If any errors are raised, the function is left unmodified so that it is
        not in an inconsistent state.

        Args:
            nodes: The node(s) to remove.
            safe: If True, performs the following actions before removal:

                1. It checks to make sure there are no users of the node that are
                   not to be removed before removing it.
                2. It checks the node does not contribute to any graph outputs.
                3. It removes references to all inputs so it is no longer a user of other nodes.

        Raises:
            ValueError: If any node to remove does not belong to this graph.
            ValueError: (When ``safe=True``) If the node is still being used by
                other nodes not to be removed, or if it contributes to a graph output.
        """
        self._graph.remove(nodes, safe=safe)
    def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
        """Insert new nodes after the given node in O(#new_nodes) time."""
        self._graph.insert_after(node, new_nodes)

    def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
        """Insert new nodes before the given node in O(#new_nodes) time."""
        self._graph.insert_before(node, new_nodes)
    def sort(self) -> None:
        """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time."""
        self._graph.sort()
    # End of mutation methods

    def __str__(self) -> str:
        full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "")
        inputs_text = ",\n".join(str(x) for x in self.inputs)
        outputs_text = ",\n".join(str(x) for x in self.outputs)
        attributes_text = ",\n".join(
            f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is not None)
            for attr in self.attributes.values()
        )
        if attributes_text:
            attributes_text = (
                "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}"
            )
        signature = f"""\
<
    opset_imports={self.opset_imports!r},
>
def {full_name}(
    inputs=(
{textwrap.indent(inputs_text, " " * 8)}
    ),{textwrap.indent(attributes_text, " " * 4)}
    outputs=(
{textwrap.indent(outputs_text, " " * 8)}
    ),
)"""
        node_count = len(self)
        number_width = len(str(node_count))
        node_lines = []
        for i, node in enumerate(self):
            node_name = node.name if node.name else f":anonymous_node:{id(node)}"
            node_text = f"# {node_name}\n{node}"
            indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
            # Remove the leading spaces
            indented_node_text = indented_node_text.strip()
            node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
        returns = ", ".join(str(x) for x in self.outputs)
        body = (
            "{\n"
            + textwrap.indent("\n".join(node_lines), " " * 4)
            + textwrap.indent(f"\nreturn {returns}", " " * 4)
            + "\n}"
        )
        return f"{signature} {body}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}, outputs={self.outputs!r})"


class Attr(
    _protocols.AttributeProtocol,
    _protocols.ReferenceAttributeProtocol,
    _display.PrettyPrintable,
):
    """Base class for ONNX attributes or references."""

    __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")

    def __init__(
        self,
        name: str,
        type: _enums.AttributeType,
        value: Any,
        ref_attr_name: str | None = None,
        *,
        doc_string: str | None = None,
    ) -> None:
        self._name = name
        self._type = type
        self._value = value
        self._ref_attr_name = ref_attr_name
        self.doc_string = doc_string

    @property
    def name(self) -> str:
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        self._name = value

    @property
    def type(self) -> _enums.AttributeType:
        return self._type

    @property
    def value(self) -> Any:
        return self._value

    @property
    def ref_attr_name(self) -> str | None:
        return self._ref_attr_name
    def is_ref(self) -> bool:
        """Check if this attribute is a reference attribute."""
        return self.ref_attr_name is not None
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _protocols.AttributeProtocol):
            return False
        if self.name != other.name:
            return False
        if self.type != other.type:
            return False
        if self.value != other.value:
            return False
        if self.doc_string != other.doc_string:
            return False
        return True

    def __str__(self) -> str:
        if self.is_ref():
            return f"@{self.ref_attr_name}"
        if self.type == _enums.AttributeType.GRAPH:
            return textwrap.indent("\n" + str(self.value), " " * 4)
        return str(self.value)

    def __repr__(self) -> str:
        if self.is_ref():
            return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
        return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"

    # Well typed getters
    def as_float(self) -> float:
        """Get the attribute value as a float."""
        if self.type != _enums.AttributeType.FLOAT:
            raise TypeError(
                f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
            )
        # Do not use isinstance check because it may prevent np.float32 etc. from being used
        return float(self.value)

    def as_int(self) -> int:
        """Get the attribute value as an int."""
        if self.type != _enums.AttributeType.INT:
            raise TypeError(
                f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
            )
        # Do not use isinstance check because it may prevent np.int32 etc. from being used
        return int(self.value)

    def as_string(self) -> str:
        """Get the attribute value as a string."""
        if self.type != _enums.AttributeType.STRING:
            raise TypeError(
                f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
            )
        if not isinstance(self.value, str):
            raise TypeError(f"Value of attribute '{self!r}' is not a string.")
        return self.value

    def as_tensor(self) -> _protocols.TensorProtocol:
        """Get the attribute value as a tensor."""
        if self.type != _enums.AttributeType.TENSOR:
            raise TypeError(
                f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
            )
        if not isinstance(self.value, _protocols.TensorProtocol):
            raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
        return self.value

    def as_graph(self) -> Graph:
        """Get the attribute value as a graph."""
        if self.type != _enums.AttributeType.GRAPH:
            raise TypeError(
                f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
            )
        if not isinstance(self.value, Graph):
            raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
        return self.value

    def as_floats(self) -> Sequence[float]:
        """Get the attribute value as a sequence of floats."""
        if self.type != _enums.AttributeType.FLOATS:
            raise TypeError(
                f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
            )
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        # Do not use isinstance check on elements because it may prevent np.float32 etc. from being used
        # Create a copy of the list to prevent mutation
        return [float(v) for v in self.value]

    def as_ints(self) -> Sequence[int]:
        """Get the attribute value as a sequence of ints."""
        if self.type != _enums.AttributeType.INTS:
            raise TypeError(
                f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
            )
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
        # Create a copy of the list to prevent mutation
        return list(self.value)

    def as_strings(self) -> Sequence[str]:
        """Get the attribute value as a sequence of strings."""
        if self.type != _enums.AttributeType.STRINGS:
            raise TypeError(
                f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
            )
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        if onnx_ir.DEBUG:
            if not all(isinstance(x, str) for x in self.value):
                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
        # Create a copy of the list to prevent mutation
        return list(self.value)

    def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
        """Get the attribute value as a sequence of tensors."""
        if self.type != _enums.AttributeType.TENSORS:
            raise TypeError(
                f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
            )
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        if onnx_ir.DEBUG:
            if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
        # Create a copy of the list to prevent mutation
        return list(self.value)

    def as_graphs(self) -> Sequence[Graph]:
        """Get the attribute value as a sequence of graphs."""
        if self.type != _enums.AttributeType.GRAPHS:
            raise TypeError(
                f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
            )
        if not isinstance(self.value, Sequence):
            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
        if onnx_ir.DEBUG:
            if not all(isinstance(x, Graph) for x in self.value):
                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
        # Create a copy of the list to prevent mutation
        return list(self.value)
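
# A minimal sketch of reading node attributes through the typed getters above.
# ``node`` and the attribute names are hypothetical; ``node.attributes`` is the
# Attributes mapping defined in _graph_containers.
def _example_typed_getters(node: Node) -> None:
    axis = node.attributes["axis"].as_int()  # raises TypeError unless type is INT
    perm_attr = node.attributes.get("perm")
    perm = perm_attr.as_ints() if perm_attr is not None else None  # copied list
    _ = (axis, perm)
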
# NOTE: The following functions are just for convenience


def RefAttr(
    name: str,
    ref_attr_name: str,
    type: _enums.AttributeType,
    doc_string: str | None = None,
) -> Attr:
    """Create a reference attribute.

    Args:
        name: The name of the attribute.
        ref_attr_name: The name of the referenced attribute.
        type: The type of the attribute.
        doc_string: Documentation string.

    Returns:
        A reference attribute.
    """
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)


def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
    """Create a float attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.FLOAT,
        value,
        doc_string=doc_string,
    )


def AttrInt64(name: str, value: int, doc_string: str | None = None) -> Attr:
    """Create an int attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.INT,
        value,
        doc_string=doc_string,
    )


def AttrString(name: str, value: str, doc_string: str | None = None) -> Attr:
    """Create a str attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.STRING,
        value,
        doc_string=doc_string,
    )


def AttrTensor(
    name: str, value: _protocols.TensorProtocol, doc_string: str | None = None
) -> Attr:
    """Create a tensor attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.TENSOR,
        value,
        doc_string=doc_string,
    )


def AttrGraph(name: str, value: Graph, doc_string: str | None = None) -> Attr:
    """Create a graph attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.GRAPH,
        value,
        doc_string=doc_string,
    )


def AttrFloat32s(name: str, value: Sequence[float], doc_string: str | None = None) -> Attr:
    """Create a float sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.FLOATS,
        value,
        doc_string=doc_string,
    )


def AttrInt64s(name: str, value: Sequence[int], doc_string: str | None = None) -> Attr:
    """Create an int sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.INTS,
        value,
        doc_string=doc_string,
    )


def AttrStrings(name: str, value: Sequence[str], doc_string: str | None = None) -> Attr:
    """Create a string sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.STRINGS,
        value,
        doc_string=doc_string,
    )


def AttrTensors(
    name: str, value: Sequence[_protocols.TensorProtocol], doc_string: str | None = None
) -> Attr:
    """Create a tensor sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.TENSORS,
        value,
        doc_string=doc_string,
    )


def AttrGraphs(name: str, value: Sequence[Graph], doc_string: str | None = None) -> Attr:
    """Create a graph sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.GRAPHS,
        value,
        doc_string=doc_string,
    )


# NOTE: SparseTensor should be a sparse tensor proto
def AttrSparseTensor(
    name: str, value: _protocols.SparseTensorProtocol, doc_string: str | None = None
) -> Attr:
    """Create a sparse tensor attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.SPARSE_TENSOR,
        value,
        doc_string=doc_string,
    )


def AttrSparseTensors(
    name: str, value: Sequence[_protocols.SparseTensorProtocol], doc_string: str | None = None
) -> Attr:
    """Create a sparse tensor sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.SPARSE_TENSORS,
        value,
        doc_string=doc_string,
    )


@dataclasses.dataclass
class TypeAndShape:
    """Type and shape.

    Useful for constructing a type proto.
    """

    type: _protocols.TypeProtocol | None
    shape: Shape | None


def AttrTypeProto(name: str, value: TypeAndShape, doc_string: str | None = None) -> Attr:
    """Create a type attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.TYPE_PROTO,
        value,
        doc_string=doc_string,
    )


def AttrTypeProtos(
    name: str, value: Sequence[TypeAndShape], doc_string: str | None = None
) -> Attr:
    """Create a type sequence attribute."""
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    return Attr(
        name,
        _enums.AttributeType.TYPE_PROTOS,
        value,
        doc_string=doc_string,
    )
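
# A minimal sketch of the convenience factories above: each one is shorthand for
# constructing an ``Attr`` with the matching AttributeType. The names and values
# chosen here are arbitrary.
def _example_attr_factories() -> None:
    alpha = AttrFloat32("alpha", 0.2)
    assert alpha.type == _enums.AttributeType.FLOAT
    assert alpha.as_float() == 0.2
    ref = RefAttr("kernel_shape", "kernel_shape", _enums.AttributeType.INTS)
    assert ref.is_ref() and ref.value is None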