# Source code for onnx_ir._enums

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""ONNX IR enums that matches the ONNX spec."""

from __future__ import annotations

import enum

import ml_dtypes
import numpy as np


class AttributeType(enum.IntEnum):
    """Kinds of value an ONNX attribute may carry.

    Member values are the integer codes used by the ONNX spec, so a member can
    be built directly from a raw code, e.g. ``AttributeType(7)`` is ``INTS``.
    """

    UNDEFINED = 0
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
    TYPE_PROTO = 13
    TYPE_PROTOS = 14

    def __repr__(self) -> str:
        # Show just the bare member name (e.g. ``INTS``) instead of the
        # default ``<AttributeType.INTS: 7>`` form.
        return self.name

    def __str__(self) -> str:
        # str() and repr() intentionally produce identical text.
        return repr(self)


class DataType(enum.IntEnum):
    """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""

    # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
    # but we should stick to the names used in the ONNX spec for consistency.
    UNDEFINED = 0
    FLOAT = 1
    UINT8 = 2
    INT8 = 3
    UINT16 = 4
    INT16 = 5
    INT32 = 6
    INT64 = 7
    STRING = 8
    BOOL = 9
    FLOAT16 = 10
    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14
    COMPLEX128 = 15
    BFLOAT16 = 16
    FLOAT8E4M3FN = 17
    FLOAT8E4M3FNUZ = 18
    FLOAT8E5M2 = 19
    FLOAT8E5M2FNUZ = 20
    UINT4 = 21
    INT4 = 22
    FLOAT4E2M1 = 23

    @classmethod
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        """Returns the ONNX data type for the numpy dtype.

        Resolution order:

        1. Dtypes registered in ``_NP_TYPE_TO_DATA_TYPE`` (builtin numpy dtypes
           plus the ``ml_dtypes`` extension types).
        2. Any numpy string dtype (subtype of ``np.str_``) maps to ``STRING``.
        3. Structured dtypes that ``onnx`` uses to emulate low-precision types
           (``onnx._custom_element_types``), recognized by their field names.

        Args:
            dtype: The numpy dtype to convert.

        Raises:
            TypeError: If the data type is not supported by ONNX.
        """
        if dtype in _NP_TYPE_TO_DATA_TYPE:
            return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
        if np.issubdtype(dtype, np.str_):
            return DataType.STRING

        # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
        # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
        # ``names`` is None for non-structured dtypes; the isinstance guard also
        # keeps the dict lookup from hashing an unexpected value.
        field_names = getattr(dtype, "names", None)
        if isinstance(field_names, tuple) and field_names in _ONNX_CUSTOM_DTYPE_FIELD_NAMES:
            return _ONNX_CUSTOM_DTYPE_FIELD_NAMES[field_names]
        raise TypeError(f"Unsupported numpy data type: {dtype}")

    @classmethod
    def from_short_name(cls, short_name: str) -> DataType:
        """Returns the ONNX data type for the short name.

        This is the inverse of :meth:`short_name`, e.g. ``"f32"`` -> ``FLOAT``.

        Raises:
            TypeError: If ``short_name`` is not a recognized short name.
        """
        if short_name not in _SHORT_NAME_TO_DATA_TYPE:
            raise TypeError(f"Unknown short name: {short_name}")
        return cls(_SHORT_NAME_TO_DATA_TYPE[short_name])

    @property
    def itemsize(self) -> float:
        """Returns the size of the data type in bytes.

        Sub-byte types (``UINT4``, ``INT4``, ``FLOAT4E2M1``) report ``0.5``.
        ``UNDEFINED`` has no entry in ``_ITEMSIZE_MAP`` and raises ``KeyError``.
        """
        return _ITEMSIZE_MAP[self]

    def numpy(self) -> np.dtype:
        """Returns the numpy dtype for the ONNX data type.

        ``STRING`` maps to the numpy ``object`` dtype.

        Raises:
            TypeError: If the data type is not supported by numpy.
        """
        if self not in _DATA_TYPE_TO_NP_TYPE:
            raise TypeError(f"Numpy does not support ONNX data type: {self}")
        return _DATA_TYPE_TO_NP_TYPE[self]

    def short_name(self) -> str:
        """Returns the short name of the data type.

        The short name is a string that is used to represent the data type in a more
        compact form. For example, the short name for `DataType.FLOAT` is "f32".
        To get the corresponding data type back, call ``from_short_name`` on a string.

        Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py

        Raises:
            TypeError: If the short name is not available for the data type.
        """
        if self not in _DATA_TYPE_TO_SHORT_NAME:
            raise TypeError(f"Short name not available for ONNX data type: {self}")
        return _DATA_TYPE_TO_SHORT_NAME[self]

    def is_floating_point(self) -> bool:
        """Returns True if the data type is a floating point type.

        Complex types (``COMPLEX64``, ``COMPLEX128``) are not counted as
        floating point.
        """
        # Membership test against a frozenset built once at import time rather
        # than a fresh set literal on every call.
        return self in _FLOATING_POINT_DATA_TYPES

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.__repr__()


# Bytes per element for each ONNX data type (0.5 for 4-bit types).
_ITEMSIZE_MAP = {
    DataType.FLOAT: 4,
    DataType.UINT8: 1,
    DataType.INT8: 1,
    DataType.UINT16: 2,
    DataType.INT16: 2,
    DataType.INT32: 4,
    DataType.INT64: 8,
    DataType.STRING: 1,
    DataType.BOOL: 1,
    DataType.FLOAT16: 2,
    DataType.DOUBLE: 8,
    DataType.UINT32: 4,
    DataType.UINT64: 8,
    DataType.COMPLEX64: 8,
    DataType.COMPLEX128: 16,
    DataType.BFLOAT16: 2,
    DataType.FLOAT8E4M3FN: 1,
    DataType.FLOAT8E4M3FNUZ: 1,
    DataType.FLOAT8E5M2: 1,
    DataType.FLOAT8E5M2FNUZ: 1,
    DataType.UINT4: 0.5,
    DataType.INT4: 0.5,
    DataType.FLOAT4E2M1: 0.5,
}

# Data types reported as floating point by ``DataType.is_floating_point``.
_FLOATING_POINT_DATA_TYPES = frozenset(
    {
        DataType.FLOAT,
        DataType.FLOAT16,
        DataType.DOUBLE,
        DataType.BFLOAT16,
        DataType.FLOAT8E4M3FN,
        DataType.FLOAT8E4M3FNUZ,
        DataType.FLOAT8E5M2,
        DataType.FLOAT8E5M2FNUZ,
        DataType.FLOAT4E2M1,
    }
)

# Field names of the structured dtypes that onnx (as of 1.18) uses to emulate
# low-precision element types; used as a fallback by ``DataType.from_numpy``.
_ONNX_CUSTOM_DTYPE_FIELD_NAMES = {
    ("bfloat16",): DataType.BFLOAT16,
    ("e4m3fn",): DataType.FLOAT8E4M3FN,
    ("e4m3fnuz",): DataType.FLOAT8E4M3FNUZ,
    ("e5m2",): DataType.FLOAT8E5M2,
    ("e5m2fnuz",): DataType.FLOAT8E5M2FNUZ,
    ("uint4",): DataType.UINT4,
    ("int4",): DataType.INT4,
    ("float4e2m1",): DataType.FLOAT4E2M1,
}

# We use ml_dtypes to support dtypes that are not in numpy.
_NP_TYPE_TO_DATA_TYPE = {
    np.dtype("bool"): DataType.BOOL,
    np.dtype("complex128"): DataType.COMPLEX128,
    np.dtype("complex64"): DataType.COMPLEX64,
    np.dtype("float16"): DataType.FLOAT16,
    np.dtype("float32"): DataType.FLOAT,
    np.dtype("float64"): DataType.DOUBLE,
    np.dtype("int16"): DataType.INT16,
    np.dtype("int32"): DataType.INT32,
    np.dtype("int64"): DataType.INT64,
    np.dtype("int8"): DataType.INT8,
    np.dtype("object"): DataType.STRING,
    np.dtype("uint16"): DataType.UINT16,
    np.dtype("uint32"): DataType.UINT32,
    np.dtype("uint64"): DataType.UINT64,
    np.dtype("uint8"): DataType.UINT8,
    np.dtype(ml_dtypes.bfloat16): DataType.BFLOAT16,
    np.dtype(ml_dtypes.float8_e4m3fn): DataType.FLOAT8E4M3FN,
    np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
    np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
    np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
    np.dtype(ml_dtypes.int4): DataType.INT4,
    np.dtype(ml_dtypes.uint4): DataType.UINT4,
}
# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
# ``float4_e2m1fn`` only exists in newer ml_dtypes releases, so register it
# conditionally.
if hasattr(ml_dtypes, "float4_e2m1fn"):
    _NP_TYPE_TO_DATA_TYPE[np.dtype(ml_dtypes.float4_e2m1fn)] = DataType.FLOAT4E2M1
# ONNX DataType to Numpy dtype. Each DataType appears at most once as a value
# of _NP_TYPE_TO_DATA_TYPE, so inverting it loses no entries (STRING maps back
# to the numpy ``object`` dtype).
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}

# ONNX DataType to its compact short name (e.g. FLOAT -> "f32"), used by
# DataType.short_name.
_DATA_TYPE_TO_SHORT_NAME = {
    DataType.UNDEFINED: "undefined",
    DataType.BFLOAT16: "bf16",
    DataType.DOUBLE: "f64",
    DataType.FLOAT: "f32",
    DataType.FLOAT16: "f16",
    DataType.FLOAT8E4M3FN: "f8e4m3fn",
    DataType.FLOAT8E5M2: "f8e5m2",
    DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
    DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
    DataType.FLOAT4E2M1: "f4e2m1",
    DataType.COMPLEX64: "c64",
    DataType.COMPLEX128: "c128",
    DataType.INT4: "i4",
    DataType.INT8: "i8",
    DataType.INT16: "i16",
    DataType.INT32: "i32",
    DataType.INT64: "i64",
    DataType.BOOL: "b8",
    DataType.UINT4: "u4",
    DataType.UINT8: "u8",
    DataType.UINT16: "u16",
    DataType.UINT32: "u32",
    DataType.UINT64: "u64",
    DataType.STRING: "s",
}

# Short name back to ONNX DataType (inverse of _DATA_TYPE_TO_SHORT_NAME); used
# by DataType.from_short_name. Short names above are all distinct, so the
# inversion is lossless.
_SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()}