Source code for onnx_ir.tensor_adapters

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Compatible adapters implementing the TensorProtocol interface for various framework tensor types.

This module provides public classes that implement the :class:`onnx_ir.TensorProtocol`
interface for various tensor types from popular deep learning frameworks.

You can use these classes to create tensors and use them in the IR graph like any other tensor.

Example::

    import torch
    import onnx_ir as ir

    # Create a PyTorch tensor
    torch_tensor = torch.tensor([1, 2, 3])

    # Wrap the PyTorch tensor in a TorchTensor object
    ir_tensor = ir.tensor_adapters.TorchTensor(torch_tensor)

    # Use the IR tensor in the graph
    attr = ir.AttrTensor("x", ir_tensor)
    print(attr)
"""

# pylint: disable=import-outside-toplevel

# NOTE: DO NOT import any framework-specific modules here in the global namespace.

from __future__ import annotations

# Public names of this module (the ones exported by a star-import).
__all__ = [
    "from_torch_dtype",
    "to_torch_dtype",
    "TorchTensor",
]

import ctypes
from typing import TYPE_CHECKING, Any

import numpy.typing as npt

import onnx_ir as ir
from onnx_ir import _core

if TYPE_CHECKING:
    import torch


# Lazily built dtype lookup tables. Both stay ``None`` until the first call to
# from_torch_dtype() / to_torch_dtype(), which import torch locally and fill
# them in, so importing this module never imports torch at module scope.
_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None


def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
    """Convert a PyTorch dtype to an ONNX IR DataType.

    The lookup table is created on the first call and cached in the
    module-level ``_TORCH_DTYPE_TO_ONNX`` so that ``torch`` is only imported
    when this function is actually used.

    Args:
        dtype: The PyTorch dtype to convert.

    Returns:
        The ONNX IR data type mapped to ``dtype``.

    Raises:
        TypeError: If ``dtype`` has no entry in the lookup table.
    """
    global _TORCH_DTYPE_TO_ONNX
    if _TORCH_DTYPE_TO_ONNX is None:
        import torch

        # NOTE: insertion order is part of the behavior; it is the order in
        # which supported dtypes are listed in the error message below.
        _TORCH_DTYPE_TO_ONNX = {
            torch.bfloat16: ir.DataType.BFLOAT16,
            torch.bool: ir.DataType.BOOL,
            torch.complex128: ir.DataType.COMPLEX128,
            torch.complex64: ir.DataType.COMPLEX64,
            torch.float16: ir.DataType.FLOAT16,
            torch.float32: ir.DataType.FLOAT,
            torch.float64: ir.DataType.DOUBLE,
            torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,
            torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
            torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
            torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
            torch.int16: ir.DataType.INT16,
            torch.int32: ir.DataType.INT32,
            torch.int64: ir.DataType.INT64,
            torch.int8: ir.DataType.INT8,
            torch.uint8: ir.DataType.UINT8,
            torch.uint16: ir.DataType.UINT16,
            torch.uint32: ir.DataType.UINT32,
            torch.uint64: ir.DataType.UINT64,
        }

    # No table value is None, so None reliably signals a missing key.
    onnx_dtype = _TORCH_DTYPE_TO_ONNX.get(dtype)
    if onnx_dtype is None:
        raise TypeError(
            f"Unsupported PyTorch dtype '{dtype}'. "
            "Please use a supported dtype from the list: "
            f"{list(_TORCH_DTYPE_TO_ONNX)}"
        )
    return onnx_dtype
def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
    """Convert an ONNX IR DataType to a PyTorch dtype.

    The lookup table is created on the first call and cached in the
    module-level ``_ONNX_DTYPE_TO_TORCH`` so that ``torch`` is only imported
    when this function is actually used.

    Args:
        dtype: The ONNX IR data type to convert.

    Returns:
        The PyTorch dtype mapped to ``dtype``.

    Raises:
        TypeError: If ``dtype`` has no entry in the lookup table.
    """
    global _ONNX_DTYPE_TO_TORCH
    if _ONNX_DTYPE_TO_TORCH is None:
        import torch

        # NOTE: insertion order is part of the behavior; it is the order in
        # which supported dtypes are listed in the error message below.
        _ONNX_DTYPE_TO_TORCH = {
            ir.DataType.BFLOAT16: torch.bfloat16,
            ir.DataType.BOOL: torch.bool,
            ir.DataType.COMPLEX128: torch.complex128,
            ir.DataType.COMPLEX64: torch.complex64,
            ir.DataType.FLOAT16: torch.float16,
            ir.DataType.FLOAT: torch.float32,
            ir.DataType.DOUBLE: torch.float64,
            ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
            ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
            ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
            ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
            ir.DataType.INT16: torch.int16,
            ir.DataType.INT32: torch.int32,
            ir.DataType.INT64: torch.int64,
            ir.DataType.INT8: torch.int8,
            ir.DataType.UINT8: torch.uint8,
            ir.DataType.UINT16: torch.uint16,
            ir.DataType.UINT32: torch.uint32,
            ir.DataType.UINT64: torch.uint64,
        }

    # No table value is None, so None reliably signals a missing key.
    torch_dtype = _ONNX_DTYPE_TO_TORCH.get(dtype)
    if torch_dtype is None:
        raise TypeError(
            f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
            "Please use a supported dtype from the list: "
            f"{list(_ONNX_DTYPE_TO_TORCH)}"
        )
    return torch_dtype
class TorchTensor(_core.Tensor):
    """IR tensor that wraps a ``torch.Tensor``.

    The wrapped tensor is passed to the base class as its raw data, and the
    ONNX dtype is derived from the tensor's torch dtype.
    """

    def __init__(
        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
    ):
        """Wrap ``tensor`` as an IR tensor.

        Args:
            tensor: The PyTorch tensor to wrap; stored as the raw data.
            name: Optional tensor name forwarded to the base class.
            doc_string: Optional doc string forwarded to the base class.

        Raises:
            TypeError: If the tensor's dtype is rejected by ``from_torch_dtype``.
        """
        onnx_dtype = from_torch_dtype(tensor.dtype)
        # The torch tensor itself is the raw data of the base tensor.
        super().__init__(tensor, dtype=onnx_dtype, name=name, doc_string=doc_string)

    def numpy(self) -> npt.NDArray:
        """Return the tensor content as a NumPy array.

        bfloat16 and float8 tensors are first reinterpreted as unsigned
        integers of the same width (uint16 / uint8), converted to NumPy, and
        then viewed as the NumPy dtype reported by ``self.dtype.numpy()``.
        All other dtypes are converted directly.
        """
        import torch

        torch_tensor: torch.Tensor = self.raw
        onnx_dtype = self.dtype

        # Pick the same-width integer dtype used to carry the raw bits across;
        # presumably needed because these dtypes have no direct torch -> NumPy
        # conversion.
        if onnx_dtype == ir.DataType.BFLOAT16:
            carrier_dtype = torch.uint16
        elif onnx_dtype in {
            ir.DataType.FLOAT8E4M3FN,
            ir.DataType.FLOAT8E4M3FNUZ,
            ir.DataType.FLOAT8E5M2,
            ir.DataType.FLOAT8E5M2FNUZ,
        }:
            carrier_dtype = torch.uint8
        else:
            return torch_tensor.numpy(force=True)

        bits = torch_tensor.view(carrier_dtype).numpy(force=True)
        return bits.view(onnx_dtype.numpy())

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
        """Array-protocol hook returning ``self.numpy()``.

        Args:
            dtype: Optional target dtype; when given, the conversion is
                delegated to the NumPy array's own ``__array__(dtype)``.
            copy: Accepted for signature compatibility only; ignored.
        """
        del copy  # Unused, but needed for the signature
        array = self.numpy()
        return array if dtype is None else array.__array__(dtype)

    def tobytes(self) -> bytes:
        """Return the tensor's data as bytes, read directly from its memory.

        The tensor is detached, moved to CPU and made contiguous before its
        buffer is copied out.

        Raises:
            TypeError: If the tensor is still a ``FakeTensor`` after fake mode
                has been disabled, i.e. there is no real data to read.
        """
        # Implement tobytes to support native PyTorch types so we can use types like bfloat16
        # Reading from memory directly is also more efficient because
        # it avoids copying to a NumPy array
        import torch._subclasses.fake_tensor

        fake_tensor_module = torch._subclasses.fake_tensor  # pylint: disable=protected-access

        with fake_tensor_module.unset_fake_temporarily():
            # Disable any fake mode so calling detach() etc. will return a real tensor
            tensor = self.raw.detach().cpu().contiguous()

        if isinstance(tensor, fake_tensor_module.FakeTensor):
            raise TypeError(
                f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
                "with a tensor backed by real data using ONNXProgram.apply_weights() "
                "or save the model without initializers by setting include_initializers=False."
            )

        # ctypes array type covering element_size() * numel() bytes, mapped
        # onto the tensor's storage; bytes() copies it out while ``tensor``
        # is still alive.
        buffer_type = ctypes.c_ubyte * tensor.element_size() * tensor.numel()
        return bytes(buffer_type.from_address(tensor.data_ptr()))