Source code for onnx.mapping

# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import warnings
from typing import Any, Dict, NamedTuple, Union, cast

import numpy as np

from onnx import OptionalProto, SequenceProto, TensorProto


class TensorDtypeMap(NamedTuple):
    np_dtype: np.dtype
    storage_dtype: int
    name: str


# tensor_dtype: (numpy type, storage type, string name)
TENSOR_TYPE_MAP = {
    int(TensorProto.FLOAT): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.FLOAT), "TensorProto.FLOAT"
    ),
    int(TensorProto.UINT8): TensorDtypeMap(
        np.dtype("uint8"), int(TensorProto.INT32), "TensorProto.UINT8"
    ),
    int(TensorProto.INT8): TensorDtypeMap(
        np.dtype("int8"), int(TensorProto.INT32), "TensorProto.INT8"
    ),
    int(TensorProto.UINT16): TensorDtypeMap(
        np.dtype("uint16"), int(TensorProto.INT32), "TensorProto.UINT16"
    ),
    int(TensorProto.INT16): TensorDtypeMap(
        np.dtype("int16"), int(TensorProto.INT32), "TensorProto.INT16"
    ),
    int(TensorProto.INT32): TensorDtypeMap(
        np.dtype("int32"), int(TensorProto.INT32), "TensorProto.INT32"
    ),
    int(TensorProto.INT64): TensorDtypeMap(
        np.dtype("int64"), int(TensorProto.INT64), "TensorProto.INT64"
    ),
    int(TensorProto.BOOL): TensorDtypeMap(
        np.dtype("bool"), int(TensorProto.INT32), "TensorProto.BOOL"
    ),
    int(TensorProto.FLOAT16): TensorDtypeMap(
        np.dtype("float16"), int(TensorProto.UINT16), "TensorProto.FLOAT16"
    ),
    # Native numpy does not support bfloat16, so use float32 for now.
    int(TensorProto.BFLOAT16): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.UINT16), "TensorProto.BFLOAT16"
    ),
    int(TensorProto.DOUBLE): TensorDtypeMap(
        np.dtype("float64"), int(TensorProto.DOUBLE), "TensorProto.DOUBLE"
    ),
    int(TensorProto.COMPLEX64): TensorDtypeMap(
        np.dtype("complex64"), int(TensorProto.FLOAT), "TensorProto.COMPLEX64"
    ),
    int(TensorProto.COMPLEX128): TensorDtypeMap(
        np.dtype("complex128"), int(TensorProto.DOUBLE), "TensorProto.COMPLEX128"
    ),
    int(TensorProto.UINT32): TensorDtypeMap(
        np.dtype("uint32"), int(TensorProto.UINT32), "TensorProto.UINT32"
    ),
    int(TensorProto.UINT64): TensorDtypeMap(
        np.dtype("uint64"), int(TensorProto.UINT64), "TensorProto.UINT64"
    ),
    int(TensorProto.STRING): TensorDtypeMap(
        np.dtype("object"), int(TensorProto.STRING), "TensorProto.STRING"
    ),
    # Native numpy does not support float8 types, so use float32 for these types for now.
    int(TensorProto.FLOAT8E4M3FN): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.UINT8), "TensorProto.FLOAT8E4M3FN"
    ),
    int(TensorProto.FLOAT8E4M3FNUZ): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.UINT8), "TensorProto.FLOAT8E4M3FNUZ"
    ),
    int(TensorProto.FLOAT8E5M2): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.UINT8), "TensorProto.FLOAT8E5M2"
    ),
    int(TensorProto.FLOAT8E5M2FNUZ): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.UINT8), "TensorProto.FLOAT8E5M2FNUZ"
    ),
    # Native numpy does not support uint4/int4, so use uint8/int8 for these types for now.
    int(TensorProto.UINT4): TensorDtypeMap(
        np.dtype("uint8"), int(TensorProto.INT32), "TensorProto.UINT4"
    ),
    int(TensorProto.INT4): TensorDtypeMap(
        np.dtype("int8"), int(TensorProto.INT32), "TensorProto.INT4"
    ),
}


class DeprecatedWarningDict(dict):  # type: ignore
    def __init__(
        self,
        dictionary: dict[int, int | str | np.dtype],
        original_function: str,
        future_function: str = "",
    ) -> None:
        super().__init__(dictionary)
        self._origin_function = original_function
        self._future_function = future_function

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DeprecatedWarningDict):
            return False
        return (
            self._origin_function == other._origin_function
            and self._future_function == other._future_function
        )

    def __getitem__(self, key: int | str | np.dtype) -> Any:
        if not self._future_function:
            warnings.warn(
                str(
                    f"`mapping.{self._origin_function}` is now deprecated and will be removed in a future release. "
                    "To silence this warning, please simply use an if-else statement to get the corresponding value."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
        else:
            warnings.warn(
                str(
                    f"`mapping.{self._origin_function}` is now deprecated and will be removed in a future release. "
                    f"To silence this warning, please use `helper.{self._future_function}` instead."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
        return super().__getitem__(key)


# This map is used for converting TensorProto values into numpy arrays
TENSOR_TYPE_TO_NP_TYPE = DeprecatedWarningDict(
    {tensor_dtype: value.np_dtype for tensor_dtype, value in TENSOR_TYPE_MAP.items()},
    "TENSOR_TYPE_TO_NP_TYPE",
    "tensor_dtype_to_np_dtype",
)

# This is only used to get keys into STORAGE_TENSOR_TYPE_TO_FIELD.
# TODO(https://github.com/onnx/onnx/issues/4554): Move these variables into _mapping.py
TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE = DeprecatedWarningDict(
    {
        tensor_dtype: value.storage_dtype
        for tensor_dtype, value in TENSOR_TYPE_MAP.items()
    },
    "TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE",
    "tensor_dtype_to_storage_tensor_dtype",
)

# NP_TYPE_TO_TENSOR_TYPE will eventually be removed;
# _NP_TYPE_TO_TENSOR_TYPE will only be used internally.
_NP_TYPE_TO_TENSOR_TYPE = {
    v: k
    for k, v in TENSOR_TYPE_TO_NP_TYPE.items()
    if k
    not in (
        TensorProto.BFLOAT16,
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.UINT4,
        TensorProto.INT4,
    )
}

# Native numpy does not currently support bfloat16, so TensorProto.BFLOAT16 is ignored for now.
# A numpy float32 array is only mapped back to TensorProto.FLOAT.
NP_TYPE_TO_TENSOR_TYPE = DeprecatedWarningDict(
    cast(Dict[int, Union[int, str, Any]], _NP_TYPE_TO_TENSOR_TYPE),
    "NP_TYPE_TO_TENSOR_TYPE",
    "np_dtype_to_tensor_dtype",
)

# STORAGE_TENSOR_TYPE_TO_FIELD will eventually be removed;
# _STORAGE_TENSOR_TYPE_TO_FIELD will only be used internally.
_STORAGE_TENSOR_TYPE_TO_FIELD = {
    int(TensorProto.FLOAT): "float_data",
    int(TensorProto.INT32): "int32_data",
    int(TensorProto.INT64): "int64_data",
    int(TensorProto.UINT8): "int32_data",
    int(TensorProto.UINT16): "int32_data",
    int(TensorProto.DOUBLE): "double_data",
    int(TensorProto.COMPLEX64): "float_data",
    int(TensorProto.COMPLEX128): "double_data",
    int(TensorProto.UINT32): "uint64_data",
    int(TensorProto.UINT64): "uint64_data",
    int(TensorProto.STRING): "string_data",
    int(TensorProto.BOOL): "int32_data",
}

STORAGE_TENSOR_TYPE_TO_FIELD = DeprecatedWarningDict(
    cast(Dict[int, Union[int, str, Any]], _STORAGE_TENSOR_TYPE_TO_FIELD),
    "STORAGE_TENSOR_TYPE_TO_FIELD",
)

# This map will be removed; there is no replacement for it.
STORAGE_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(SequenceProto.TENSOR): "tensor_values",
        int(SequenceProto.SPARSE_TENSOR): "sparse_tensor_values",
        int(SequenceProto.SEQUENCE): "sequence_values",
        int(SequenceProto.MAP): "map_values",
        int(OptionalProto.OPTIONAL): "optional_value",
    },
    "STORAGE_ELEMENT_TYPE_TO_FIELD",
)

# This map will be removed; there is no replacement for it.
OPTIONAL_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(OptionalProto.TENSOR): "tensor_value",
        int(OptionalProto.SPARSE_TENSOR): "sparse_tensor_value",
        int(OptionalProto.SEQUENCE): "sequence_value",
        int(OptionalProto.MAP): "map_value",
        int(OptionalProto.OPTIONAL): "optional_value",
    },
    "OPTIONAL_ELEMENT_TYPE_TO_FIELD",
)