# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import collections.abc
import numbers
import struct
from cmath import isnan
from typing import (
Any,
Callable,
Dict,
KeysView,
List,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
import google.protobuf.message
import numpy as np
import onnx._custom_element_types as custom_np_types
from onnx import (
IR_VERSION,
AttributeProto,
FunctionProto,
GraphProto,
MapProto,
ModelProto,
NodeProto,
OperatorSetIdProto,
OptionalProto,
SequenceProto,
SparseTensorProto,
TensorProto,
TensorShapeProto,
TrainingInfoProto,
TypeProto,
ValueInfoProto,
defs,
mapping,
subbyte,
)
VersionRowType = Union[Tuple[str, int, int, int], Tuple[str, int, int, int, int]]
VersionTableType = List[VersionRowType]
AssignmentBindingType = List[Tuple[str, str]]
# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
VERSION_TABLE: VersionTableType = [
# Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
("1.0", 3, 1, 1),
("1.1", 3, 5, 1),
("1.1.2", 3, 6, 1),
("1.2", 3, 7, 1),
("1.3", 3, 8, 1),
("1.4.1", 4, 9, 1),
("1.5.0", 5, 10, 1),
("1.6.0", 6, 11, 2),
("1.7.0", 7, 12, 2, 1),
("1.8.0", 7, 13, 2, 1),
("1.8.1", 7, 13, 2, 1),
("1.9.0", 7, 14, 2, 1),
("1.10.0", 8, 15, 2, 1),
("1.10.1", 8, 15, 2, 1),
("1.10.2", 8, 15, 2, 1),
("1.11.0", 8, 16, 3, 1),
("1.12.0", 8, 17, 3, 1),
("1.13.0", 8, 18, 3, 1),
("1.13.1", 8, 18, 3, 1),
("1.14.0", 9, 19, 3, 1),
("1.14.1", 9, 19, 3, 1),
("1.15.0", 9, 20, 4, 1),
("1.16.0", 10, 21, 5, 1),
("1.17.0", 10, 22, 5, 1),
]
VersionMapType = Dict[Tuple[str, int], int]
def create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
"""Create a map from (opset-domain, opset-version) to ir-version from above table."""
result: VersionMapType = {}
def process(release_version: str, ir_version: int, *args: Any) -> None:
del release_version # Unused
for pair in zip(["ai.onnx", "ai.onnx.ml", "ai.onnx.training"], args):
if pair not in result:
result[pair] = ir_version
if pair[0] == "ai.onnx.training":
result["ai.onnx.preview.training", pair[1]] = ir_version
for row in table:
process(*row)
return result
OP_SET_ID_VERSION_MAP = create_op_set_id_version_map(VERSION_TABLE)
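# Example (illustrative): the map is keyed by (domain, opset-version) and
# yields the minimum IR version supporting it, per VERSION_TABLE above.
#
#     OP_SET_ID_VERSION_MAP[("ai.onnx", 9)]     # -> 4 (first released in ONNX 1.4.1)
#     OP_SET_ID_VERSION_MAP[("ai.onnx.ml", 3)]  # -> 8 (first released in ONNX 1.11.0)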
def find_min_ir_version_for(
opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
"""Given list of opset ids, determine minimum IR version required.
Args:
opsetidlist: A sequence of OperatorSetIdProto.
ignore_unknown: If True, ignore unknown domain and return default minimum
version for that domain.
Returns:
The minimum IR version required (integer)
"""
default_min_version = 3
def find_min(domain: str | None, version: int) -> int:
key = (domain or "ai.onnx", version)
if key in OP_SET_ID_VERSION_MAP:
return OP_SET_ID_VERSION_MAP[key]
if ignore_unknown:
return default_min_version
raise ValueError("Unsupported opset-version.")
if opsetidlist:
return max(find_min(x.domain, x.version) for x in opsetidlist)
return default_min_version # if no opsets specified
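# Example (illustrative): an opset import of the default domain at version 15
# requires at least IR version 8; unknown domains either raise or fall back
# to the default minimum.
#
#     find_min_ir_version_for([make_opsetid("", 15)])  # -> 8
#     find_min_ir_version_for([make_opsetid("custom.domain", 1)], ignore_unknown=True)  # -> 3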
def make_node(
op_type: str,
inputs: Sequence[str],
outputs: Sequence[str],
name: str | None = None,
doc_string: str | None = None,
domain: str | None = None,
overload: str | None = None,
**kwargs: Any,
) -> NodeProto:
"""Construct a NodeProto.
Args:
op_type (string): The name of the operator to construct
inputs (list of string): list of input names
outputs (list of string): list of output names
name (string, default None): optional unique identifier for NodeProto
doc_string (string, default None): optional documentation string for NodeProto
domain (string, default None): optional domain for NodeProto.
If it's None, we will just use default domain (which is empty)
overload (string, default None): optional field, used to
resolve calls to model-local functions
**kwargs (dict): the attributes of the node. The acceptable values
are documented in :func:`make_attribute`.
Returns:
NodeProto
"""
node = NodeProto()
node.op_type = op_type
node.input.extend(inputs)
node.output.extend(outputs)
if name:
node.name = name
if doc_string:
node.doc_string = doc_string
if domain is not None:
node.domain = domain
if overload is not None:
node.overload = overload
if kwargs:
node.attribute.extend(
make_attribute(key, value)
for key, value in sorted(kwargs.items())
if value is not None
)
return node
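# Example (illustrative): attributes are passed as keyword arguments and are
# converted through make_attribute; None-valued attributes are skipped.
#
#     node = make_node("Concat", inputs=["a", "b"], outputs=["c"], axis=0)
#     # node.attribute holds a single INT attribute named "axis"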
def make_operatorsetid(
domain: str,
version: int,
) -> OperatorSetIdProto:
"""Construct an OperatorSetIdProto.
Args:
domain (string): The domain of the operator set id
version (integer): Version of operator set id
Returns:
OperatorSetIdProto
"""
operatorsetid = OperatorSetIdProto()
operatorsetid.domain = domain
operatorsetid.version = version
return operatorsetid
def make_graph(
nodes: Sequence[NodeProto],
name: str,
inputs: Sequence[ValueInfoProto],
outputs: Sequence[ValueInfoProto],
initializer: Sequence[TensorProto] | None = None,
doc_string: str | None = None,
value_info: Sequence[ValueInfoProto] | None = None,
sparse_initializer: Sequence[SparseTensorProto] | None = None,
) -> GraphProto:
"""Construct a GraphProto
Args:
nodes: list of NodeProto
name (string): graph name
inputs: list of ValueInfoProto
outputs: list of ValueInfoProto
initializer: list of TensorProto
doc_string (string): graph documentation
value_info: list of ValueInfoProto
sparse_initializer: list of SparseTensorProto
Returns:
GraphProto
"""
if initializer is None:
initializer = []
if sparse_initializer is None:
sparse_initializer = []
if value_info is None:
value_info = []
graph = GraphProto()
graph.node.extend(nodes)
graph.name = name
graph.input.extend(inputs)
graph.output.extend(outputs)
graph.initializer.extend(initializer)
graph.sparse_initializer.extend(sparse_initializer)
graph.value_info.extend(value_info)
if doc_string:
graph.doc_string = doc_string
return graph
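# Example (a minimal sketch): a one-node graph computing Y = Identity(X),
# using make_tensor_value_info (defined later in this module) for the I/O.
#
#     X = make_tensor_value_info("X", TensorProto.FLOAT, [None, 3])
#     Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, 3])
#     graph = make_graph([make_node("Identity", ["X"], ["Y"])], "identity_graph", [X], [Y])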
def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
"""Construct an OperatorSetIdProto.
Args:
domain (string): The domain of the operator set id
version (integer): Version of operator set id
Returns:
OperatorSetIdProto
"""
opsetid = OperatorSetIdProto()
opsetid.domain = domain
opsetid.version = version
return opsetid
def make_function(
domain: str,
fname: str,
inputs: Sequence[str],
outputs: Sequence[str],
nodes: Sequence[NodeProto],
opset_imports: Sequence[OperatorSetIdProto],
attributes: Sequence[str] | None = None,
attribute_protos: Sequence[AttributeProto] | None = None,
doc_string: str | None = None,
overload: str | None = None,
value_info: Sequence[ValueInfoProto] | None = None,
) -> FunctionProto:
    """Construct a FunctionProto.

    Args:
        domain (string): the domain of the function
        fname (string): the name of the function
        inputs: list of input names
        outputs: list of output names
        nodes: list of NodeProto making up the function body
        opset_imports: list of OperatorSetIdProto
        attributes: list of attribute names of the function
        attribute_protos: list of AttributeProto, for attributes with default values
        doc_string (string): documentation string
        overload (string): optional field, used to resolve calls to model-local functions
        value_info: list of ValueInfoProto

    Returns:
        FunctionProto
    """
    if attributes is None:
attributes = []
if attribute_protos is None:
attribute_protos = []
if value_info is None:
value_info = []
f = FunctionProto()
f.domain = domain
f.name = fname
f.input.extend(inputs)
f.output.extend(outputs)
f.node.extend(nodes)
f.opset_import.extend(opset_imports)
f.attribute.extend(attributes)
f.attribute_proto.extend(attribute_protos)
if doc_string:
f.doc_string = doc_string
if overload is not None:
f.overload = overload
f.value_info.extend(value_info)
return f
def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
"""Construct a ModelProto
Args:
        graph (GraphProto): the graph, typically obtained from *make_graph*
**kwargs: any attribute to add to the returned instance
Returns:
ModelProto
"""
model = ModelProto()
# Touch model.ir_version so it is stored as the version from which it is
# generated.
model.ir_version = IR_VERSION
model.graph.CopyFrom(graph)
opset_imports: Sequence[OperatorSetIdProto] | None = kwargs.pop(
"opset_imports", None
)
if opset_imports is not None:
model.opset_import.extend(opset_imports)
else:
# Default import
imp = model.opset_import.add()
imp.version = defs.onnx_opset_version()
functions: Sequence[FunctionProto] | None = kwargs.pop("functions", None)
if functions is not None:
model.functions.extend(functions)
for k, v in kwargs.items():
# TODO: Does this work with repeated fields?
setattr(model, k, v)
return model
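# Example (illustrative): wrap a graph into a model; extra keyword arguments
# such as producer_name are set directly on the ModelProto via setattr.
#
#     model = make_model(graph, producer_name="example",
#                        opset_imports=[make_opsetid("", 21)])
#     # model.ir_version == IR_VERSION; model.opset_import[0].version == 21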
# An extension of make_model that infers an IR_VERSION for the model,
# if not specified, on a best-effort basis.
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Same as make_model, but infers a minimal ir_version from opset_imports when one is not specified."""
    ir_version_field = "ir_version"
if ir_version_field not in kwargs:
opset_imports_field = "opset_imports"
imports = kwargs.get(opset_imports_field, [])
kwargs[ir_version_field] = find_min_ir_version_for(imports)
return make_model(graph, **kwargs)
def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
    """Set the model's metadata_props from a dictionary of string keys and values."""
    set_metadata_props(model, dict_value)
def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
    """Flatten complex values into an interleaved [real0, imag0, real1, imag1, ...] list."""
    return [
(ca[i // 2].real if (i % 2 == 0) else ca[i // 2].imag) # type: ignore[misc]
for i in range(len(ca) * 2)
]
# Convert a float32 value to a bfloat16 (as int).
# By default, this conversion rounds to nearest-even and preserves NaN.
# Setting `truncate` to True enables a simpler conversion that just drops the
# 16 least significant bits of the float32 representation (all within the
# significand). In this mode a rounding error of up to 1 bit may be
# introduced and preservation of NaN values is not guaranteed.
def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
ival = int.from_bytes(struct.pack("<f", fval), "little")
if truncate:
return ival >> 16
# NaN requires at least 1 significand bit set
if isnan(fval):
return 0x7FC0 # sign=0, exp=all-ones, sig=0b1000000
# drop bottom 16-bits
# round remaining bits using round-to-nearest-even
rounded = ((ival >> 16) & 1) + 0x7FFF
return (ival + rounded) >> 16
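# Example (illustrative): bfloat16 keeps float32's exponent and the top 7
# significand bits, so exactly representable values agree in both modes.
#
#     float32_to_bfloat16(1.0)                 # -> 0x3F80
#     float32_to_bfloat16(1.0, truncate=True)  # -> 0x3F80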
def float32_to_float8e4m3( # noqa: PLR0911
fval: float,
scale: float = 1.0,
fn: bool = True,
uz: bool = False,
saturate: bool = True,
) -> int:
"""Convert a float32 value to a float8, e4m3 (as int).
See :ref:`onnx-detail-float8` for technical details.
Args:
fval: float to convert
scale: scale, divide *fval* by *scale* before casting it
fn: no infinite values
uz: no negative zero
        saturate: if True, any value out of range, including infinity, becomes
            the maximum representable value; otherwise it becomes NaN. The
            description of operator Cast fully describes the differences.

    Returns:
        The converted value (as an int).
    """
if not fn:
raise NotImplementedError(
"float32_to_float8e4m3 not implemented with fn=False."
)
x = fval / scale
b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
ret = (b & 0x80000000) >> 24 # sign
if uz:
if (b & 0x7FC00000) == 0x7FC00000: # noqa: PLR2004
return 0x80
if np.isinf(x):
if saturate:
return ret | 127
return 0x80
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa
if e < 116: # noqa: PLR2004
ret = 0
elif e < 120: # noqa: PLR2004
# denormalized number
ex = e - 119
if ex >= -2: # noqa: PLR2004
ret |= 1 << (2 + ex)
ret |= m >> (21 - ex)
elif m > 0:
ret |= 1
else:
ret = 0
mask = 1 << (20 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 135: # noqa: PLR2004
# normalized number
ex = e - 119 # 127 - 8
if ex == 0:
ret |= 0x4
ret |= m >> 21
else:
ret |= ex << 3
ret |= m >> 20
if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
if (ret & 0x7F) < 0x7F: # noqa: PLR2004
# rounding
ret += 1
elif not saturate:
return 0x80
elif saturate:
            ret |= 0x7F  # 0b1111111
else:
ret = 0x80
return int(ret)
else:
if (b & 0x7FC00000) == 0x7FC00000: # noqa: PLR2004
return 0x7F | ret
if np.isinf(x):
if saturate:
return ret | 126
return 0x7F | ret
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa
if e != 0:
if e < 117: # noqa: PLR2004
pass
elif e < 121: # noqa: PLR2004
# denormalized number
ex = e - 120
if ex >= -2: # noqa: PLR2004
ret |= 1 << (2 + ex)
ret |= m >> (21 - ex)
elif m > 0:
ret |= 1
mask = 1 << (20 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 136: # noqa: PLR2004
# normalized number
ex = e - 120
if ex == 0:
ret |= 0x4
ret |= m >> 21
else:
ret |= ex << 3
ret |= m >> 20
if (ret & 0x7F) == 0x7F: # noqa: PLR2004
ret &= 0xFE
if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
if (ret & 0x7F) < 0x7E: # noqa: PLR2004
# rounding
ret += 1
elif not saturate:
ret |= 0x7F
elif saturate:
ret |= 126 # 01111110
else:
ret |= 0x7F
return int(ret)
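# Example (illustrative): float8 e4m3fn has exponent bias 7, so 1.0 encodes
# as sign=0, exponent=0b0111, mantissa=0b000; out-of-range values saturate
# to the largest finite value (448) by default.
#
#     float32_to_float8e4m3(1.0)  # -> 0x38
#     float32_to_float8e4m3(1e6)  # -> 0x7E (448)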
def float32_to_float8e5m2( # noqa: PLR0911
fval: float,
scale: float = 1.0,
fn: bool = False,
uz: bool = False,
saturate: bool = True,
) -> int:
"""Convert a float32 value to a float8, e5m2 (as int).
Args:
fval: float to convert
scale: scale, divide *fval* by *scale* before casting it
fn: no infinite values
uz: no negative zero
        saturate: if True, any value out of range, including infinity, becomes
            the maximum representable value; otherwise it becomes an infinity
            or NaN depending on *fn* and *uz*. The description of operator
            Cast fully describes the differences.

    Returns:
        The converted value (as an int).
    """
x = fval / scale
b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
ret = (b & 0x80000000) >> 24 # sign
if fn and uz:
if (b & 0x7FC00000) == 0x7FC00000: # noqa: PLR2004
return 0x80
if (b & 0x7FFFFFFF) == 0x7F800000: # noqa: PLR2004
# inf
if saturate:
return ret | 0x7F
return 0x80
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa
if e < 109: # noqa: PLR2004
ret = 0
elif e < 112: # noqa: PLR2004
# denormalized number
ex = e - 111
if ex >= -1:
ret |= 1 << (1 + ex)
ret |= m >> (22 - ex)
elif m > 0:
ret |= 1
else:
ret = 0
mask = 1 << (21 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 143: # noqa: PLR2004
# normalized number
ex = e - 111
ret |= ex << 2
ret |= m >> 21
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
if (ret & 0x7F) < 0x7F: # noqa: PLR2004
# rounding
ret += 1
elif not saturate:
ret = 0x80
elif e == 255 and m == 0: # inf # noqa: PLR2004
ret = 0x80
elif saturate:
ret |= 0x7F # last possible number
else:
ret = 0x80
return int(ret)
elif not fn and not uz:
if (b & 0x7FC00000) == 0x7FC00000: # noqa: PLR2004
return 0x7F | ret
if np.isinf(x):
if saturate:
return 0x7B | ret
return 0x7C | ret
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa
if e != 0:
if e < 110: # noqa: PLR2004
pass
elif e < 113: # noqa: PLR2004
# denormalized number
ex = e - 112
if ex >= -1:
ret |= 1 << (1 + ex)
ret |= m >> (22 - ex)
elif m > 0:
ret |= 1
mask = 1 << (21 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 143: # noqa: PLR2004
# normalized number
ex = e - 112
ret |= ex << 2
ret |= m >> 21
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
if (ret & 0x7F) < 0x7B: # noqa: PLR2004
# rounding
ret += 1
elif saturate:
ret |= 0x7B
else:
ret |= 0x7C
elif saturate:
ret |= 0x7B
else:
ret |= 0x7C
return int(ret)
else:
raise NotImplementedError("fn and uz must be both False or True.")
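# Example (illustrative): float8 e5m2 has exponent bias 15, so 1.0 encodes
# as sign=0, exponent=0b01111, mantissa=0b00; by default out-of-range values
# saturate to the largest finite value (57344).
#
#     float32_to_float8e5m2(1.0)  # -> 0x3C
#     float32_to_float8e5m2(1e9)  # -> 0x7B (57344)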
def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarray:
"""Convert an array of float32 value to a 4bit data-type and pack every two concecutive elements in a byte.
See :ref:`onnx-detail-int4` for technical details.
Args:
array: array of float to convert and pack
signed: Whether the 4 bit variant is signed or unsigned
Returns:
        Packed array with size `ceil(array.size/2)` (single dimension).
"""
if not isinstance(array, np.ndarray):
array = np.asarray(array, dtype=np.float32)
array_flat = array.ravel()
is_odd_volume = np.prod(array.shape) % 2 == 1
if is_odd_volume:
array_flat = np.append(array_flat, np.array([0]))
def single_func(x, y) -> np.ndarray:
return subbyte.float32x2_to_4bitx2(x, y, signed)
func = np.frompyfunc(single_func, 2, 1)
arr: np.ndarray = func(array_flat[0::2], array_flat[1::2])
return arr.astype(np.uint8)
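# Example (illustrative): an odd-sized input is zero-padded before packing;
# the nibble order within each byte is delegated to subbyte.float32x2_to_4bitx2.
#
#     packed = pack_float32_to_4bit([1.0, 2.0, 3.0], signed=True)
#     # packed.dtype == np.uint8 and packed.size == 2, i.e. ceil(3 / 2)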
def pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
"""Convert an array of float32 value to float4e2m1 and pack every two concecutive elements in a byte.
See :ref:`onnx-detail-float4` for technical details.
Args:
array: array of float to convert and pack
Returns:
        Packed array of float4e2m1 (as uint8) with size `ceil(array.size/2)` (single dimension).
"""
if not isinstance(array, np.ndarray):
array = np.asarray(array, dtype=np.float32)
array_flat = array.ravel()
is_odd_volume = np.prod(array.shape) % 2 == 1
if is_odd_volume:
array_flat = np.append(array_flat, np.array([0]))
arr = subbyte.float32x2_to_float4e2m1x2(array_flat[0::2], array_flat[1::2])
return arr.astype(np.uint8)
def make_tensor(
name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = False
) -> TensorProto:
"""Make a TensorProto with specified arguments. If raw is False, this
function will choose the corresponding proto field to store the
values based on data_type. If raw is True, use "raw_data" proto
field to store the values, and values should be of type bytes in
this case.
Args:
name (string): tensor name
data_type (int): a value such as onnx.TensorProto.FLOAT
dims (List[int]): shape
vals: values
raw (bool): if True, vals contains the serialized content of the tensor,
otherwise, vals should be a list of values of the type defined by *data_type*
Returns:
TensorProto
"""
tensor = TensorProto()
tensor.data_type = data_type
tensor.name = name
if data_type == TensorProto.STRING and raw:
raise TypeError("Can not use raw_data to store string type.")
np_dtype = tensor_dtype_to_np_dtype(data_type)
# Check number of vals specified equals tensor size
expected_size: float = 1
if raw:
# NumPy doesn't have BFLOAT16. TENSOR_TYPE_MAP maps it to float32, which has the wrong itemsize.
if data_type == TensorProto.BFLOAT16:
expected_size = 2
elif data_type in (
TensorProto.FLOAT8E4M3FN,
TensorProto.FLOAT8E4M3FNUZ,
TensorProto.FLOAT8E5M2,
TensorProto.FLOAT8E5M2FNUZ,
):
expected_size = 1
        # NumPy doesn't have INT4/FP4. Two 4-bit elements are packed into each uint8 byte.
elif data_type in (TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1):
expected_size = 0.5
else:
expected_size = np_dtype.itemsize
if isinstance(vals, np.ndarray) and len(vals.shape) > 1:
vals = vals.flatten()
for d in dims:
expected_size *= d
if len(vals) != expected_size:
# padding of half a byte is acceptable for 4bit types
if not (
data_type in (TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1)
and len(vals) == expected_size + 0.5
):
raise ValueError(
f"Number of values does not match tensor's size. Expected {expected_size}, but it is {len(vals)}. "
)
if raw:
tensor.raw_data = vals
else:
if data_type in (TensorProto.COMPLEX64, TensorProto.COMPLEX128):
vals = split_complex_to_pairs(vals)
elif data_type == TensorProto.FLOAT16:
vals = (
np.array(vals).astype(np_dtype).view(dtype=np.uint16).flatten().tolist()
)
elif data_type in (
TensorProto.BFLOAT16,
TensorProto.FLOAT8E4M3FN,
TensorProto.FLOAT8E4M3FNUZ,
TensorProto.FLOAT8E5M2,
TensorProto.FLOAT8E5M2FNUZ,
):
fcast = {
TensorProto.BFLOAT16: float32_to_bfloat16,
TensorProto.FLOAT8E4M3FN: float32_to_float8e4m3,
TensorProto.FLOAT8E4M3FNUZ: lambda *args: float32_to_float8e4m3( # type: ignore[misc]
*args, uz=True
),
TensorProto.FLOAT8E5M2: float32_to_float8e5m2,
TensorProto.FLOAT8E5M2FNUZ: lambda *args: float32_to_float8e5m2( # type: ignore[misc]
*args, fn=True, uz=True
),
}[
data_type # type: ignore[index]
]
vals = list(
map( # type: ignore[call-overload]
fcast,
np.array(vals).astype(np_dtype).flatten().tolist(),
)
)
elif data_type in (
TensorProto.UINT4,
TensorProto.INT4,
):
signed = data_type == TensorProto.INT4
# Two packed 4-bit values must be represented as a single uint8 value.
# Therefore, pack_float32_to_4bit() sets the dtype of the output vals
# to uint8 regardless of the value of 'signed'. Using int8 would cause
# the size of int4 tensors to increase ~5x if the tensor contains negative values (due to
# the way negative values are serialized by protobuf).
vals = pack_float32_to_4bit(vals, signed=signed).flatten().tolist()
elif data_type == TensorProto.FLOAT4E2M1:
vals = pack_float32_to_float4e2m1(vals).flatten().tolist()
elif data_type == TensorProto.BOOL:
vals = np.array(vals).astype(int)
elif data_type == TensorProto.STRING:
vals = np.array(vals).astype(bytes)
field = tensor_dtype_to_field(data_type)
getattr(tensor, field).extend(vals)
tensor.dims.extend(dims)
return tensor
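# Example (illustrative): the same 2x2 float tensor built from a Python list
# and from raw little-endian bytes.
#
#     t1 = make_tensor("w", TensorProto.FLOAT, [2, 2], [1.0, 2.0, 3.0, 4.0])
#     t2 = make_tensor("w", TensorProto.FLOAT, [2, 2],
#                      np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).tobytes(), raw=True)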
def make_sparse_tensor(
values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> SparseTensorProto:
"""Construct a SparseTensorProto
Args:
values (TensorProto): the values
indices (TensorProto): the indices
dims: the shape
Returns:
SparseTensorProto
"""
sparse = SparseTensorProto()
sparse.values.CopyFrom(values)
sparse.indices.CopyFrom(indices)
sparse.dims.extend(dims)
return sparse
def make_sequence(
name: str,
elem_type: SequenceProto.DataType,
values: Sequence[Any],
) -> SequenceProto:
"""Make a Sequence with specified value arguments."""
sequence = SequenceProto()
sequence.name = name
sequence.elem_type = elem_type
if elem_type == SequenceProto.UNDEFINED:
return sequence
attribute: Sequence | None = None
if elem_type == SequenceProto.TENSOR:
attribute = sequence.tensor_values
elif elem_type == SequenceProto.SPARSE_TENSOR:
attribute = sequence.sparse_tensor_values
elif elem_type == SequenceProto.SEQUENCE:
attribute = sequence.sequence_values
elif elem_type == SequenceProto.MAP:
attribute = sequence.map_values
elif elem_type == OptionalProto.OPTIONAL:
attribute = sequence.optional_values
else:
raise TypeError("The element type in the input sequence is not supported.")
attribute.extend(values)
return sequence
def make_map(
name: str, key_type: int, keys: list[Any], values: SequenceProto
) -> MapProto:
"""Make a Map with specified key-value pair arguments.
Criteria for conversion:
- Keys and Values must have the same number of elements
- Every key in keys must be of the same type
- Every value in values must be of the same type
"""
map_proto = MapProto()
valid_key_int_types = [
TensorProto.INT8,
TensorProto.INT16,
TensorProto.INT32,
TensorProto.INT64,
TensorProto.UINT8,
TensorProto.UINT16,
TensorProto.UINT32,
TensorProto.UINT64,
]
map_proto.name = name
map_proto.key_type = key_type
if key_type == TensorProto.STRING:
map_proto.string_keys.extend(keys)
elif key_type in valid_key_int_types:
map_proto.keys.extend(keys)
map_proto.values.CopyFrom(values)
return map_proto
def make_optional(
name: str,
elem_type: OptionalProto.DataType,
value: google.protobuf.message.Message | None,
) -> OptionalProto:
"""Make an Optional with specified value arguments."""
optional = OptionalProto()
optional.name = name
optional.elem_type = elem_type
if elem_type == OptionalProto.UNDEFINED:
return optional
attribute: google.protobuf.message.Message | None = None
if elem_type == OptionalProto.TENSOR:
attribute = optional.tensor_value
elif elem_type == OptionalProto.SPARSE_TENSOR:
attribute = optional.sparse_tensor_value
elif elem_type == OptionalProto.SEQUENCE:
attribute = optional.sequence_value
elif elem_type == OptionalProto.MAP:
attribute = optional.map_value
elif elem_type == OptionalProto.OPTIONAL:
attribute = optional.optional_value
else:
raise TypeError("The element type in the input optional is not supported.")
assert value is not None
attribute.CopyFrom(value) # type: ignore[arg-type]
return optional
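# Example (illustrative): wrap a tensor in an optional value; an empty
# optional carries only the element type.
#
#     t = make_tensor("v", TensorProto.INT64, [1], [7])
#     opt = make_optional("maybe_v", OptionalProto.TENSOR, t)
#     empty = make_optional("nothing", OptionalProto.UNDEFINED, None)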
def _to_bytes(value: str | bytes) -> bytes:
"""Coerce a string (or bytes) value into UTF-8 bytes."""
if isinstance(value, str):
return value.encode("utf-8")
return value
def make_attribute(
key: str,
value: Any,
doc_string: str | None = None,
attr_type: int | None = None,
) -> AttributeProto:
"""Makes an AttributeProto based on the value type."""
attr = AttributeProto()
attr.name = key
if doc_string:
attr.doc_string = doc_string
# Singular cases
if isinstance(value, numbers.Integral):
attr.i = int(value)
attr.type = AttributeProto.INT
elif isinstance(value, numbers.Real):
attr.f = float(value)
attr.type = AttributeProto.FLOAT
elif isinstance(value, (str, bytes)):
# Encode strings into utf-8
attr.s = _to_bytes(value)
attr.type = AttributeProto.STRING
elif isinstance(value, TensorProto):
attr.t.CopyFrom(value)
attr.type = AttributeProto.TENSOR
elif isinstance(value, SparseTensorProto):
attr.sparse_tensor.CopyFrom(value)
attr.type = AttributeProto.SPARSE_TENSOR
elif isinstance(value, GraphProto):
attr.g.CopyFrom(value)
attr.type = AttributeProto.GRAPH
elif isinstance(value, TypeProto):
attr.tp.CopyFrom(value)
attr.type = AttributeProto.TYPE_PROTO
# Iterable cases
elif isinstance(value, collections.abc.Iterable):
value = list(value)
if len(value) == 0 and attr_type is None:
raise ValueError(
f"Could not infer attribute `{key}` type from empty iterator"
)
if attr_type is None:
types = {type(v) for v in value}
for exp_t, exp_enum in (
(numbers.Integral, AttributeProto.INTS),
(numbers.Real, AttributeProto.FLOATS),
((str, bytes), AttributeProto.STRINGS),
(TensorProto, AttributeProto.TENSORS),
(SparseTensorProto, AttributeProto.SPARSE_TENSORS),
(GraphProto, AttributeProto.GRAPHS),
(TypeProto, AttributeProto.TYPE_PROTOS),
):
if all(issubclass(t, exp_t) for t in types):
attr_type = exp_enum
break
if attr_type is None:
raise ValueError(
"Could not infer the attribute type from the elements of the passed Iterable value."
)
if attr_type == AttributeProto.INTS:
attr.ints.extend(value)
attr.type = AttributeProto.INTS
elif attr_type == AttributeProto.FLOATS:
attr.floats.extend(value)
attr.type = AttributeProto.FLOATS
elif attr_type == AttributeProto.STRINGS:
attr.strings.extend(_to_bytes(v) for v in value)
attr.type = AttributeProto.STRINGS
elif attr_type == AttributeProto.TENSORS:
attr.tensors.extend(value)
attr.type = AttributeProto.TENSORS
elif attr_type == AttributeProto.SPARSE_TENSORS:
attr.sparse_tensors.extend(value)
attr.type = AttributeProto.SPARSE_TENSORS
elif attr_type == AttributeProto.GRAPHS:
attr.graphs.extend(value)
attr.type = AttributeProto.GRAPHS
elif attr_type == AttributeProto.TYPE_PROTOS:
attr.type_protos.extend(value)
attr.type = AttributeProto.TYPE_PROTOS
else:
raise AssertionError() # Should not reach since `ValueError` must be raised in attr_type checking
else:
raise TypeError(f"'{value}' is not an accepted attribute value.")
if attr_type is not None and attr.type != attr_type:
raise TypeError(
f"Inferred attribute type '{_attr_type_to_str(attr.type)}'({attr.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
)
return attr
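# Example (illustrative): the attribute type is inferred from the Python
# value unless attr_type is given; empty iterables require an explicit type.
#
#     make_attribute("alpha", 0.5)          # -> FLOAT
#     make_attribute("pads", [0, 0, 1, 1])  # -> INTS
#     make_attribute("empty", [], attr_type=AttributeProto.STRINGS)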
def make_attribute_ref(
name: str, attr_type: AttributeProto.AttributeType, doc_string: str | None = None
) -> AttributeProto:
"""Make an AttributeProto holding a reference to the parent function's attribute of given name and type."""
attr = AttributeProto()
attr.name = name
attr.type = attr_type
if doc_string:
attr.doc_string = doc_string
return attr
def get_attribute_value(attr: AttributeProto) -> Any: # noqa: PLR0911
    """Return the value stored in an AttributeProto, typed per attr.type."""
    if attr.ref_attr_name:
raise ValueError(f"Cannot get value of reference attribute: {attr}")
if attr.type == AttributeProto.FLOAT:
return attr.f
if attr.type == AttributeProto.INT:
return attr.i
if attr.type == AttributeProto.STRING:
return attr.s
if attr.type == AttributeProto.TENSOR:
return attr.t
if attr.type == AttributeProto.SPARSE_TENSOR:
return attr.sparse_tensor
if attr.type == AttributeProto.GRAPH:
return attr.g
if attr.type == AttributeProto.TYPE_PROTO:
return attr.tp
if attr.type == AttributeProto.FLOATS:
return list(attr.floats)
if attr.type == AttributeProto.INTS:
return list(attr.ints)
if attr.type == AttributeProto.STRINGS:
return list(attr.strings)
if attr.type == AttributeProto.TENSORS:
return list(attr.tensors)
if attr.type == AttributeProto.SPARSE_TENSORS:
return list(attr.sparse_tensors)
if attr.type == AttributeProto.GRAPHS:
return list(attr.graphs)
if attr.type == AttributeProto.TYPE_PROTOS:
return list(attr.type_protos)
if attr.type == AttributeProto.UNDEFINED:
return None
raise ValueError(f"Unsupported ONNX attribute: {attr}")
def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    """Return the value of the attribute of *node* named *attr_name*."""
    matching = [x for x in node.attribute if x.name == attr_name]
if len(matching) > 1:
raise ValueError(f"Node has multiple attributes with name {attr_name}")
if len(matching) < 1:
raise ValueError(f"Node has no attribute with name {attr_name}")
return get_attribute_value(matching[0])
def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    """Make a ValueInfoProto with only the name set (no type or shape)."""
    value_info_proto = ValueInfoProto()
value_info_proto.name = name
return value_info_proto
def make_tensor_type_proto(
elem_type: int,
shape: Sequence[str | int | None] | None,
shape_denotation: list[str] | None = None,
) -> TypeProto:
"""Makes a Tensor TypeProto based on the data type and shape."""
type_proto = TypeProto()
tensor_type_proto = type_proto.tensor_type
tensor_type_proto.elem_type = elem_type
tensor_shape_proto = tensor_type_proto.shape
if shape is not None:
# You might think this is a no-op (extending a normal Python
# list by [] certainly is), but protobuf lists work a little
# differently; if a field is never set, it is omitted from the
# resulting protobuf; a list that is explicitly set to be
# empty will get an (empty) entry in the protobuf. This
# difference is visible to our consumers, so make sure we emit
# an empty shape!
tensor_shape_proto.dim.extend([])
if shape_denotation and len(shape_denotation) != len(shape):
raise ValueError(
"Invalid shape_denotation. Must be of the same length as shape."
)
for i, d in enumerate(shape):
dim = tensor_shape_proto.dim.add()
if d is None:
pass
elif isinstance(d, int):
dim.dim_value = d
elif isinstance(d, str):
dim.dim_param = d
else:
raise ValueError(
f"Invalid item in shape: {d}. Needs to be of int or str."
)
if shape_denotation:
dim.denotation = shape_denotation[i]
return type_proto
def make_tensor_value_info(
name: str,
elem_type: int,
shape: Sequence[str | int | None] | None,
doc_string: str = "",
shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
"""Makes a ValueInfoProto based on the data type and shape."""
value_info_proto = ValueInfoProto()
value_info_proto.name = name
if doc_string:
value_info_proto.doc_string = doc_string
tensor_type_proto = make_tensor_type_proto(elem_type, shape, shape_denotation)
value_info_proto.type.CopyFrom(tensor_type_proto)
return value_info_proto
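# Example (illustrative): ints become fixed dimensions (dim_value), strings
# become symbolic dimensions (dim_param), and None leaves a dimension unknown.
#
#     vi = make_tensor_value_info("X", TensorProto.FLOAT, ["N", 3, None])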
def make_sparse_tensor_type_proto(
elem_type: int,
shape: Sequence[str | int | None] | None,
shape_denotation: list[str] | None = None,
) -> TypeProto:
"""Makes a SparseTensor TypeProto based on the data type and shape."""
type_proto = TypeProto()
sparse_tensor_type_proto = type_proto.sparse_tensor_type
sparse_tensor_type_proto.elem_type = elem_type
sparse_tensor_shape_proto = sparse_tensor_type_proto.shape
if shape is not None:
# You might think this is a no-op (extending a normal Python
# list by [] certainly is), but protobuf lists work a little
# differently; if a field is never set, it is omitted from the
# resulting protobuf; a list that is explicitly set to be
# empty will get an (empty) entry in the protobuf. This
# difference is visible to our consumers, so make sure we emit
# an empty shape!
sparse_tensor_shape_proto.dim.extend([])
if shape_denotation and len(shape_denotation) != len(shape):
raise ValueError(
"Invalid shape_denotation. Must be of the same length as shape."
)
for i, d in enumerate(shape):
dim = sparse_tensor_shape_proto.dim.add()
if d is None:
pass
elif isinstance(d, int):
dim.dim_value = d
elif isinstance(d, str):
dim.dim_param = d
else:
raise ValueError(
f"Invalid item in shape: {d}. Needs to be of int or text."
)
if shape_denotation:
dim.denotation = shape_denotation[i]
return type_proto
def make_sparse_tensor_value_info(
name: str,
elem_type: int,
shape: Sequence[str | int | None] | None,
doc_string: str = "",
shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
"""Makes a SparseTensor ValueInfoProto based on the data type and shape."""
value_info_proto = ValueInfoProto()
value_info_proto.name = name
if doc_string:
value_info_proto.doc_string = doc_string
sparse_tensor_type_proto = make_sparse_tensor_type_proto(
elem_type, shape, shape_denotation
)
value_info_proto.type.sparse_tensor_type.CopyFrom(
sparse_tensor_type_proto.sparse_tensor_type
)
return value_info_proto
def make_sequence_type_proto(
inner_type_proto: TypeProto,
) -> TypeProto:
"""Makes a sequence TypeProto."""
type_proto = TypeProto()
type_proto.sequence_type.elem_type.CopyFrom(inner_type_proto)
return type_proto
def make_optional_type_proto(
inner_type_proto: TypeProto,
) -> TypeProto:
"""Makes an optional TypeProto."""
type_proto = TypeProto()
type_proto.optional_type.elem_type.CopyFrom(inner_type_proto)
return type_proto
def make_map_type_proto(
key_type: int,
value_type: TypeProto,
) -> TypeProto:
"""Makes a map TypeProto."""
type_proto = TypeProto()
type_proto.map_type.key_type = key_type
type_proto.map_type.value_type.CopyFrom(value_type)
return type_proto
def make_value_info(
name: str,
type_proto: TypeProto,
doc_string: str = "",
) -> ValueInfoProto:
"""Makes a ValueInfoProto with the given type_proto."""
value_info_proto = ValueInfoProto()
value_info_proto.name = name
if doc_string:
value_info_proto.doc_string = doc_string
value_info_proto.type.CopyFrom(type_proto)
return value_info_proto
def _sanitize_str(s: str | bytes) -> str:
if isinstance(s, str):
sanitized = s
elif isinstance(s, bytes):
sanitized = s.decode("utf-8", errors="ignore")
else:
sanitized = str(s)
if len(sanitized) < 64: # noqa: PLR2004
return sanitized
return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>"
def make_tensor_sequence_value_info(
name: str,
elem_type: int,
shape: Sequence[str | int | None] | None,
doc_string: str = "",
elem_shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
"""Makes a Sequence[Tensors] ValueInfoProto based on the data type and shape."""
value_info_proto = ValueInfoProto()
value_info_proto.name = name
if doc_string:
value_info_proto.doc_string = doc_string
tensor_type_proto = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
sequence_type_proto = make_sequence_type_proto(tensor_type_proto)
value_info_proto.type.sequence_type.CopyFrom(sequence_type_proto.sequence_type)
return value_info_proto
def printable_attribute(
attr: AttributeProto, subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
content = []
content.append(attr.name)
content.append("=")
def str_float(f: float) -> str:
# NB: Different Python versions print different numbers of trailing
# decimals, specifying this explicitly keeps it consistent for all
# versions
return f"{f:.15g}"
def str_int(i: int) -> str:
return str(i)
_T = TypeVar("_T")
def str_list(str_elem: Callable[[_T], str], xs: Sequence[_T]) -> str:
return "[" + ", ".join(map(str_elem, xs)) + "]"
    # for now, this logic should continue to work as long as we are running
    # on a proto2 implementation. If/when we switch to proto3, we will need
    # to use attr.type
# To support printing subgraphs, if we find a graph attribute, print out
# its name here and pass the graph itself up to the caller for later
# printing.
graphs = []
if attr.HasField("f"):
content.append(str_float(attr.f))
elif attr.HasField("i"):
content.append(str_int(attr.i))
elif attr.HasField("s"):
# TODO: Bit nervous about Python 2 / Python 3 determinism implications
content.append(repr(_sanitize_str(attr.s)))
elif attr.HasField("t"):
if len(attr.t.dims) > 0:
content.append("<Tensor>")
else:
# special case to print scalars
field = tensor_dtype_to_field(attr.t.data_type)
content.append(f"<Scalar Tensor {getattr(attr.t, field)}>")
elif attr.HasField("g"):
content.append(f"<graph {attr.g.name}>")
graphs.append(attr.g)
elif attr.HasField("tp"):
content.append(f"<Type Proto {attr.tp}>")
elif attr.floats:
content.append(str_list(str_float, attr.floats))
elif attr.ints:
content.append(str_list(str_int, attr.ints))
elif attr.strings:
# TODO: Bit nervous about Python 2 / Python 3 determinism implications
content.append(str(list(map(_sanitize_str, attr.strings))))
elif attr.tensors:
content.append("[<Tensor>, ...]")
elif attr.type_protos:
content.append("[")
for i, tp in enumerate(attr.type_protos):
comma = "," if i != len(attr.type_protos) - 1 else ""
content.append(f"<Type Proto {tp}>{comma}")
content.append("]")
elif attr.graphs:
content.append("[")
for i, g in enumerate(attr.graphs):
comma = "," if i != len(attr.graphs) - 1 else ""
content.append(f"<graph {g.name}>{comma}")
content.append("]")
graphs.extend(attr.graphs)
else:
content.append("<Unknown>")
if subgraphs:
return " ".join(content), graphs
return " ".join(content)
def printable_dim(dim: TensorShapeProto.Dimension) -> str:
which = dim.WhichOneof("value")
if which is None:
return "?"
return str(getattr(dim, which))
def printable_type(t: TypeProto) -> str:
if t.WhichOneof("value") == "tensor_type":
s = TensorProto.DataType.Name(t.tensor_type.elem_type)
if t.tensor_type.HasField("shape"):
if len(t.tensor_type.shape.dim):
s += str(", " + "x".join(map(printable_dim, t.tensor_type.shape.dim)))
else:
s += ", scalar"
return s
if t.WhichOneof("value") is None:
return ""
return f"Unknown type {t.WhichOneof('value')}"
def printable_value_info(v: ValueInfoProto) -> str:
s = f"%{v.name}"
if v.type:
s = f"{s}[{printable_type(v.type)}]"
return s
def printable_tensor_proto(t: TensorProto) -> str:
s = f"%{t.name}["
s += TensorProto.DataType.Name(t.data_type)
if t.dims is not None:
if len(t.dims):
s += str(", " + "x".join(map(str, t.dims)))
else:
s += ", scalar"
s += "]"
return s
def printable_node(
node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
content = []
if len(node.output):
content.append(", ".join([f"%{name}" for name in node.output]))
content.append("=")
# To deal with nested graphs
graphs: list[GraphProto] = []
printed_attrs = []
for attr in node.attribute:
if subgraphs:
printed_attr_subgraphs = printable_attribute(attr, subgraphs)
if not isinstance(printed_attr_subgraphs[1], list):
raise TypeError(
f"printed_attr_subgraphs[1] must be an instance of {list}."
)
graphs.extend(printed_attr_subgraphs[1])
printed_attrs.append(printed_attr_subgraphs[0])
else:
printed = printable_attribute(attr)
if not isinstance(printed, str):
raise TypeError(f"printed must be an instance of {str}.")
printed_attrs.append(printed)
printed_attributes = ", ".join(sorted(printed_attrs))
printed_inputs = ", ".join([f"%{name}" for name in node.input])
if node.attribute:
content.append(f"{node.op_type}[{printed_attributes}]({printed_inputs})")
else:
content.append(f"{node.op_type}({printed_inputs})")
if subgraphs:
return prefix + " ".join(content), graphs
return prefix + " ".join(content)
def printable_graph(graph: GraphProto, prefix: str = "") -> str:
"""Display a GraphProto as a string.
Args:
graph (GraphProto): the graph to display
prefix (string): prefix of every line
Returns:
string
"""
content = []
indent = prefix + " "
# header
header = ["graph", graph.name]
initializers = {t.name for t in graph.initializer}
if len(graph.input):
header.append("(")
in_strs = [] # required inputs
        # optional inputs whose matching initializer provides a default value
        in_with_init_strs: list[str] = []
for inp in graph.input:
if inp.name not in initializers:
in_strs.append(printable_value_info(inp))
else:
in_with_init_strs.append(printable_value_info(inp))
if in_strs:
content.append(prefix + " ".join(header))
header = []
for line in in_strs:
content.append(prefix + " " + line) # noqa: PERF401
header.append(")")
if in_with_init_strs:
header.append("optional inputs with matching initializers (")
content.append(prefix + " ".join(header))
header = []
for line in in_with_init_strs:
content.append(prefix + " " + line) # noqa: PERF401
header.append(")")
# from IR 4 onwards an initializer is not required to have a matching graph input
# so output the name, type and shape of those as well
if len(in_with_init_strs) < len(initializers):
graph_inputs = {i.name for i in graph.input}
init_strs = [
printable_tensor_proto(i)
for i in graph.initializer
if i.name not in graph_inputs
]
header.append("initializers (")
content.append(prefix + " ".join(header))
header = []
for line in init_strs:
content.append(prefix + " " + line) # noqa: PERF401
header.append(")")
header.append("{")
content.append(prefix + " ".join(header))
graphs: list[GraphProto] = []
# body
for node in graph.node:
contents_subgraphs = printable_node(node, indent, subgraphs=True)
if not isinstance(contents_subgraphs[1], list):
raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
content.append(contents_subgraphs[0])
graphs.extend(contents_subgraphs[1])
# tail
tail = ["return"]
if len(graph.output):
tail.append(", ".join([f"%{out.name}" for out in graph.output]))
content.append(indent + " ".join(tail))
# closing bracket
content.append(prefix + "}")
for g in graphs:
content.append("\n" + printable_graph(g)) # noqa: PERF401
return "\n".join(content)
def strip_doc_string(proto: google.protobuf.message.Message) -> None:
"""Empties `doc_string` field on any nested protobuf messages"""
if not isinstance(proto, google.protobuf.message.Message):
raise TypeError(
f"proto must be an instance of {google.protobuf.message.Message}."
)
for descriptor in proto.DESCRIPTOR.fields:
if descriptor.name == "doc_string":
proto.ClearField(descriptor.name)
elif descriptor.type == descriptor.TYPE_MESSAGE:
if descriptor.label == descriptor.LABEL_REPEATED:
for x in getattr(proto, descriptor.name):
strip_doc_string(x)
elif proto.HasField(descriptor.name):
strip_doc_string(getattr(proto, descriptor.name))
def make_training_info(
algorithm: GraphProto,
algorithm_bindings: AssignmentBindingType,
initialization: GraphProto | None,
initialization_bindings: AssignmentBindingType | None,
) -> TrainingInfoProto:
training_info = TrainingInfoProto()
training_info.algorithm.CopyFrom(algorithm)
for k, v in algorithm_bindings:
binding = training_info.update_binding.add()
binding.key = k
binding.value = v
if initialization:
training_info.initialization.CopyFrom(initialization)
if initialization_bindings:
for k, v in initialization_bindings:
binding = training_info.initialization_binding.add()
binding.key = k
binding.value = v
return training_info
# The following functions map between TensorProto data types, numpy dtypes, and storage fields.
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
"""Convert a TensorProto's data_type to corresponding numpy dtype. It can be used while making tensor.
Args:
tensor_dtype: TensorProto's data_type
Returns:
numpy's data_type
"""
return mapping.TENSOR_TYPE_MAP[tensor_dtype].np_dtype
def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
"""Convert a TensorProto's data_type to corresponding data_type for storage.
Args:
tensor_dtype: TensorProto's data_type
Returns:
data_type for storage
"""
return mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
def tensor_dtype_to_string(tensor_dtype: int) -> str:
"""Get the name of given TensorProto's data_type.
Args:
tensor_dtype: TensorProto's data_type
Returns:
the name of data_type
"""
return mapping.TENSOR_TYPE_MAP[tensor_dtype].name
def tensor_dtype_to_field(tensor_dtype: int) -> str:
"""Convert a TensorProto's data_type to corresponding field name for storage. It can be used while making tensors.
Args:
tensor_dtype: TensorProto's data_type
Returns:
field name
"""
return mapping._STORAGE_TENSOR_TYPE_TO_FIELD[
mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
]
def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> int:
"""Convert a numpy's dtype to corresponding tensor type. It can be used while converting numpy arrays to tensors.
Args:
np_dtype: numpy's data_type
Returns:
TensorsProto's data_type
"""
if np_dtype in mapping._NP_TYPE_TO_TENSOR_TYPE:
return cast(
int,
mapping._NP_TYPE_TO_TENSOR_TYPE[np_dtype],
)
if np.issubdtype(np_dtype, np.str_):
return TensorProto.STRING
if np_dtype in {
custom_np_types.bfloat16,
custom_np_types.float8e4m3fn,
custom_np_types.float8e4m3fnuz,
custom_np_types.float8e5m2,
custom_np_types.float8e5m2fnuz,
custom_np_types.int4,
custom_np_types.uint4,
custom_np_types.float4e2m1,
}:
return custom_np_types.mapping_name_to_data_type[np_dtype.descr[0][0]]
raise ValueError(
f"Unable to convert type {np_dtype!r} into TensorProto element type."
)
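# Example (illustrative): round-tripping between TensorProto element types
# and numpy dtypes.
#
#     tensor_dtype_to_np_dtype(TensorProto.FLOAT)  # -> dtype('float32')
#     np_dtype_to_tensor_dtype(np.dtype("int64"))  # -> TensorProto.INT64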
def get_all_tensor_dtypes() -> KeysView[int]:
"""Get all tensor types from TensorProto.
Returns:
all tensor types from TensorProto
"""
return mapping.TENSOR_TYPE_MAP.keys()
# Maps each AttributeProto.AttributeType value (int) to its name (str);
# items() yields (name, value) pairs.
_ATTRIBUTE_TYPE_TO_STR: dict[int, str] = {
    value: name for name, value in AttributeProto.AttributeType.items()
}
def _attr_type_to_str(attr_type: int) -> str:
"""Convert AttributeProto type to string.
Args:
attr_type: AttributeProto type.
Returns:
String representing the supplied attr_type.
"""
if attr_type in AttributeProto.AttributeType.values():
return _ATTRIBUTE_TYPE_TO_STR[attr_type]
    # Fall back to the first enum name ("UNDEFINED") for out-of-range values.
    return AttributeProto.AttributeType.keys()[0]