Source code for onnx.helper

# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import collections.abc
import numbers
import struct
from cmath import isnan
from typing import (
    Any,
    Callable,
    Dict,
    KeysView,
    List,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import google.protobuf.message
import numpy as np

import onnx._custom_element_types as custom_np_types
from onnx import (
    IR_VERSION,
    AttributeProto,
    FunctionProto,
    GraphProto,
    MapProto,
    ModelProto,
    NodeProto,
    OperatorSetIdProto,
    OptionalProto,
    SequenceProto,
    SparseTensorProto,
    TensorProto,
    TensorShapeProto,
    TrainingInfoProto,
    TypeProto,
    ValueInfoProto,
    defs,
    mapping,
    subbyte,
)

# One VERSION_TABLE row: (release version, IR version, ai.onnx opset version,
# ai.onnx.ml opset version), optionally followed by the ai.onnx.training opset
# version for releases that define it.
VersionRowType = Union[Tuple[str, int, int, int], Tuple[str, int, int, int, int]]
VersionTableType = List[VersionRowType]
# NOTE(review): not referenced in this part of the module; presumably a list of
# (name, name) binding pairs used elsewhere — confirm.
AssignmentBindingType = List[Tuple[str, str]]

# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
# Rows are in release order; create_op_set_id_version_map keeps the first IR
# version it sees for each (domain, opset version) pair, so the order matters.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
    ("1.13.0", 8, 18, 3, 1),
    ("1.13.1", 8, 18, 3, 1),
    ("1.14.0", 9, 19, 3, 1),
    ("1.14.1", 9, 19, 3, 1),
    ("1.15.0", 9, 20, 4, 1),
    ("1.16.0", 10, 21, 5, 1),
    ("1.17.0", 10, 22, 5, 1),
]

# Maps (opset domain, opset version) to an IR version; built from
# VERSION_TABLE by create_op_set_id_version_map.
VersionMapType = Dict[Tuple[str, int], int]


def create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
    """Build a lookup from ``(opset domain, opset version)`` to an IR version.

    Rows of *table* are visited in order, and the first IR version seen for a
    given ``(domain, version)`` pair wins. The opset numbers following the IR
    version in each row are matched positionally to the domains ``ai.onnx``,
    ``ai.onnx.ml`` and ``ai.onnx.training``. Whenever a new
    ``ai.onnx.training`` pair is recorded, the same IR version is also
    recorded for ``ai.onnx.preview.training``.

    Args:
        table: rows shaped like the entries of ``VERSION_TABLE``.

    Returns:
        A dict mapping ``(domain, version)`` to an IR version.
    """
    domains = ("ai.onnx", "ai.onnx.ml", "ai.onnx.training")
    version_map: VersionMapType = {}
    for _release, ir_version, *opset_versions in table:
        for key in zip(domains, opset_versions):
            if key in version_map:
                continue  # an earlier release already introduced this opset
            version_map[key] = ir_version
            if key[0] == "ai.onnx.training":
                version_map[("ai.onnx.preview.training", key[1])] = ir_version
    return version_map
# (domain, opset version) -> IR version lookup derived from VERSION_TABLE;
# consulted by find_min_ir_version_for.
OP_SET_ID_VERSION_MAP = create_op_set_id_version_map(VERSION_TABLE)
def find_min_ir_version_for(
    opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
    """Return the smallest IR version able to host every opset in *opsetidlist*.

    Each opset id is resolved through ``OP_SET_ID_VERSION_MAP``. An empty
    domain is treated as ``"ai.onnx"``. The largest resulting IR version is
    returned. An empty *opsetidlist* yields the default minimum of 3.

    Args:
        opsetidlist: A sequence of OperatorSetIdProto.
        ignore_unknown: If True, an opset id missing from the table
            contributes the default minimum version instead of raising.

    Returns:
        The minimum IR version required (integer).

    Raises:
        ValueError: if an opset id is unknown and *ignore_unknown* is False.
    """
    default_min_version = 3  # IR version of the oldest VERSION_TABLE row

    def lookup(domain: str | None, version: int) -> int:
        ir_version = OP_SET_ID_VERSION_MAP.get((domain or "ai.onnx", version))
        if ir_version is not None:
            return ir_version
        if not ignore_unknown:
            raise ValueError("Unsupported opset-version.")
        return default_min_version

    if not opsetidlist:
        return default_min_version  # if no opsets specified
    return max(lookup(opset.domain, opset.version) for opset in opsetidlist)
def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: str | None = None,
    doc_string: str | None = None,
    domain: str | None = None,
    overload: str | None = None,
    **kwargs: Any,
) -> NodeProto:
    """Construct a NodeProto.

    Args:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto;
            not written when empty
        doc_string (string, default None): optional documentation string for
            NodeProto; not written when empty
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        overload (string, default None): optional field, used to
            resolve calls to model-local functions
        **kwargs (dict): the attributes of the node. The acceptable values
            are documented in :func:`make_attribute`. Keyword arguments whose
            value is None are skipped; the rest are added sorted by name.

    Returns:
        NodeProto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)

    # name/doc_string are skipped when falsy, while domain/overload are
    # written whenever they are not None (even if empty).
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if overload is not None:
        node.overload = overload

    attributes = [
        make_attribute(key, kwargs[key])
        for key in sorted(kwargs)
        if kwargs[key] is not None
    ]
    if attributes:
        node.attribute.extend(attributes)
    return node
def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id

    Returns:
        OperatorSetIdProto with both fields assigned.
    """
    opset = OperatorSetIdProto()
    opset.domain, opset.version = domain, version
    return opset
def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Sequence[TensorProto] | None = None,
    doc_string: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
    sparse_initializer: Sequence[SparseTensorProto] | None = None,
) -> GraphProto:
    """Construct a GraphProto.

    Args:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto (None means no initializers)
        doc_string (string): graph documentation; not written when empty
        value_info: list of ValueInfoProto (None means none)
        sparse_initializer: list of SparseTensorProto (None means none)

    Returns:
        GraphProto
    """

    def _or_empty(items):
        # Optional sequence arguments default to None; treat that as empty.
        return () if items is None else items

    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(_or_empty(initializer))
    graph.sparse_initializer.extend(_or_empty(sparse_initializer))
    graph.value_info.extend(_or_empty(value_info))
    if doc_string:
        graph.doc_string = doc_string
    return graph
def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id

    Returns:
        OperatorSetIdProto with both fields assigned.
    """
    result = OperatorSetIdProto()
    for field, value in (("domain", domain), ("version", version)):
        setattr(result, field, value)
    return result
def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Sequence[str] | None = None,
    attribute_protos: Sequence[AttributeProto] | None = None,
    doc_string: str | None = None,
    overload: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
) -> FunctionProto:
    """Construct a FunctionProto.

    Args:
        domain: domain the function is declared in.
        fname: function name (stored in ``name``).
        inputs: input names.
        outputs: output names.
        nodes: NodeProto list forming the function body.
        opset_imports: opsets imported by the body.
        attributes: attribute names (stored in ``attribute``); None means none.
        attribute_protos: AttributeProto entries (stored in
            ``attribute_proto``); None means none.
        doc_string: documentation; not written when empty.
        overload: overload identifier; written whenever not None.
        value_info: ValueInfoProto entries; None means none.

    Returns:
        FunctionProto
    """
    func = FunctionProto()
    func.domain = domain
    func.name = fname
    func.input.extend(inputs)
    func.output.extend(outputs)
    func.node.extend(nodes)
    func.opset_import.extend(opset_imports)
    func.attribute.extend(() if attributes is None else attributes)
    func.attribute_proto.extend(() if attribute_protos is None else attribute_protos)
    if doc_string:
        func.doc_string = doc_string
    if overload is not None:
        func.overload = overload
    func.value_info.extend(() if value_info is None else value_info)
    return func
def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto around a copy of *graph*.

    ``ir_version`` is first set to this package's ``IR_VERSION``. An
    ``ir_version`` keyword, if given, overrides it through the generic
    ``setattr`` loop below.

    Args:
        graph (GraphProto): *make_graph* returns; it is copied into the model.
        **kwargs: ``opset_imports`` (default: a single import of the default
            domain at ``defs.onnx_opset_version()``) and ``functions`` are
            consumed here. Every other keyword is assigned to the model with
            ``setattr``.

    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Sequence[OperatorSetIdProto] | None = kwargs.pop(
        "opset_imports", None
    )
    if opset_imports is None:
        # Default import: domain left unset, current opset version.
        model.opset_import.add().version = defs.onnx_opset_version()
    else:
        model.opset_import.extend(opset_imports)

    functions: Sequence[FunctionProto] | None = kwargs.pop("functions", None)
    if functions is not None:
        model.functions.extend(functions)

    for field, value in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, field, value)
    return model
# An extension of make_model that infers an IR_VERSION for the model, # if not specified, using a best-effort-basis.
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Like :func:`make_model`, but infer ``ir_version`` when it is not given.

    The IR version is computed by ``find_min_ir_version_for`` from the
    ``opset_imports`` keyword, using an empty list when that keyword is
    absent. Unknown opsets propagate that helper's ValueError.

    NOTE(review): when ``opset_imports`` is absent, the inferred IR version is
    the helper's default minimum, while ``make_model`` then adds a default
    import at the current opset version. Confirm that this pairing is intended.
    """
    if "ir_version" not in kwargs:
        kwargs["ir_version"] = find_min_ir_version_for(
            kwargs.get("opset_imports", [])
        )
    return make_model(graph, **kwargs)
def set_metadata_props(
    proto: (
        ModelProto
        | GraphProto
        | FunctionProto
        | NodeProto
        | TensorProto
        | ValueInfoProto
    ),
    dict_value: dict[str, str],
) -> None:
    """Replace every ``metadata_props`` entry of *proto* with *dict_value*.

    Existing entries are cleared first. One key/value entry is then appended
    per item, in the dict's iteration order.
    """
    props = proto.metadata_props
    del props[:]
    for key, value in dict_value.items():
        entry = props.add()
        entry.key, entry.value = key, value
def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
    """Replace the ``metadata_props`` of *model* with the items of *dict_value*."""
    set_metadata_props(proto=model, dict_value=dict_value)
def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
    """Flatten complex numbers into ``[re0, im0, re1, im1, ...]``.

    Returns a new list twice as long as *ca*.
    """
    return [part for c in ca for part in (c.real, c.imag)]  # type: ignore[misc]
# Convert a float32 value to a bfloat16 (as int).
# By default, this conversion rounds to nearest, ties to even, and supports NaN.
# Setting `truncate` to True enables a simpler conversion. In this mode the
# conversion is performed by simply dropping the 2 least significant bytes of
# the significand. In this mode an error of up to 1 bit may be introduced and
# preservation of NaN values is not guaranteed.
def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
    """Return the bfloat16 bit pattern (as int) of *fval* packed as float32.

    With ``truncate=True`` the upper 16 bits of the float32 pattern are
    returned as-is. Otherwise the value is rounded to nearest, ties to even,
    and any NaN maps to the canonical 0x7FC0.
    """
    bits = int.from_bytes(struct.pack("<f", fval), "little")
    if not truncate:
        # NaN requires at least 1 significand bit set
        if isnan(fval):
            return 0x7FC0  # sign=0, exp=all-ones, sig=0b1000000
        # Add 0x7FFF plus the lowest kept bit so that exact ties round to even.
        bits += 0x7FFF + ((bits >> 16) & 1)
    return bits >> 16
def float32_to_float8e4m3(  # noqa: PLR0911
    fval: float,
    scale: float = 1.0,
    fn: bool = True,
    uz: bool = False,
    saturate: bool = True,
) -> int:
    """Convert a float32 value to a float8, e4m3 (as int).

    See :ref:`onnx-detail-float8` for technical details.

    Args:
        fval: float to convert
        scale: scale, divide *fval* by *scale* before casting it
        fn: no infinite values (only ``fn=True`` is implemented)
        uz: no negative zero. When True the E4M3FNUZ encoding is produced,
            whose only NaN is 0x80. Otherwise E4M3FN is produced, with NaN
            encoded as 0x7F / 0xFF.
        saturate: if True, any value out of range, including inf, becomes the
            maximum value, otherwise, it becomes NaN. The description of
            operator Cast fully describes the differences.

    Returns:
        the 8-bit encoding of the converted value, as an int

    Raises:
        NotImplementedError: if *fn* is False.
    """
    if not fn:
        raise NotImplementedError(
            "float32_to_float8e4m3 not implemented with fn=False."
        )
    x = fval / scale
    # Raw IEEE-754 bit pattern of x rounded to float32.
    b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign
    if uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x80
        if np.isinf(x):
            if saturate:
                return ret | 127
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 116:  # noqa: PLR2004
            ret = 0
        elif e < 120:  # noqa: PLR2004
            # denormalized number
            ex = e - 119
            if ex >= -2:  # noqa: PLR2004
                ret |= 1 << (2 + ex)
                ret |= m >> (21 - ex)
            elif m > 0:
                ret |= 1
            else:
                ret = 0
            mask = 1 << (20 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 135:  # noqa: PLR2004
            # normalized number
            ex = e - 119  # 127 - 8
            if ex == 0:
                ret |= 0x4
                ret |= m >> 21
            else:
                ret |= ex << 3
                ret |= m >> 20
            if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
                if (ret & 0x7F) < 0x7F:  # noqa: PLR2004
                    # rounding
                    ret += 1
                elif not saturate:
                    return 0x80
        elif saturate:
            ret |= 0x7F  # 01111111, the largest finite E4M3FNUZ value
        else:
            ret = 0x80
        return int(ret)
    else:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x7F | ret
        if np.isinf(x):
            if saturate:
                return ret | 126
            return 0x7F | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 117:  # noqa: PLR2004
                pass
            elif e < 121:  # noqa: PLR2004
                # denormalized number
                ex = e - 120
                if ex >= -2:  # noqa: PLR2004
                    ret |= 1 << (2 + ex)
                    ret |= m >> (21 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (20 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 136:  # noqa: PLR2004
                # normalized number
                ex = e - 120
                if ex == 0:
                    ret |= 0x4
                    ret |= m >> 21
                else:
                    ret |= ex << 3
                    ret |= m >> 20
                    if (ret & 0x7F) == 0x7F:  # noqa: PLR2004
                        # 0x7F is NaN in E4M3FN. The truncated magnitude is
                        # already >= 480, beyond the midpoint (464) between
                        # the maximum 448 and the next step, so it overflows:
                        # NaN without saturation, 448 (0x7E) with it.
                        if not saturate:
                            return ret
                        ret &= 0xFE
                if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
                    if (ret & 0x7F) < 0x7E:  # noqa: PLR2004
                        # rounding
                        ret += 1
                    elif not saturate:
                        ret |= 0x7F
            elif saturate:
                ret |= 126  # 01111110
            else:
                ret |= 0x7F
        return int(ret)
def float32_to_float8e5m2(  # noqa: PLR0911
    fval: float,
    scale: float = 1.0,
    fn: bool = False,
    uz: bool = False,
    saturate: bool = True,
) -> int:
    """Convert a float32 value to a float8, e5m2 (as int).

    Args:
        fval: float to convert
        scale: scale, divide *fval* by *scale* before casting it
        fn: no infinite values
        uz: no negative zero. Only ``fn == uz`` is supported: both False gives
            E5M2 (with infinities), both True gives E5M2FNUZ (single NaN 0x80).
        saturate: if True, any value out of range, including inf, becomes the
            maximum value, otherwise, it becomes NaN (E5M2FNUZ) or infinity
            (E5M2). The description of operator Cast fully describes the
            differences.

    Returns:
        the 8-bit encoding of the converted value, as an int

    Raises:
        NotImplementedError: if *fn* and *uz* differ.
    """
    x = fval / scale
    # Raw IEEE-754 bit pattern of x rounded to float32.
    b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign

    if fn and uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x80
        if (b & 0x7FFFFFFF) == 0x7F800000:  # noqa: PLR2004
            # inf
            if saturate:
                return ret | 0x7F
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 109:  # noqa: PLR2004
            ret = 0
        elif e < 112:  # noqa: PLR2004
            # denormalized number
            ex = e - 111
            if ex >= -1:
                ret |= 1 << (1 + ex)
                ret |= m >> (22 - ex)
            elif m > 0:
                ret |= 1
            else:
                ret = 0
            mask = 1 << (21 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 143:  # noqa: PLR2004
            # normalized number
            ex = e - 111
            ret |= ex << 2
            ret |= m >> 21
            if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                if (ret & 0x7F) < 0x7F:  # noqa: PLR2004
                    # rounding
                    ret += 1
                elif not saturate:
                    ret = 0x80
        elif e == 255 and m == 0:  # inf # noqa: PLR2004
            ret = 0x80
        elif saturate:
            ret |= 0x7F  # last possible number
        else:
            ret = 0x80
        return int(ret)
    elif not fn and not uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x7F | ret
        if np.isinf(x):
            if saturate:
                return 0x7B | ret
            return 0x7C | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 110:  # noqa: PLR2004
                pass
            elif e < 113:  # noqa: PLR2004
                # denormalized number
                ex = e - 112
                if ex >= -1:
                    ret |= 1 << (1 + ex)
                    ret |= m >> (22 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (21 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 143:  # noqa: PLR2004
                # normalized number
                ex = e - 112
                ret |= ex << 2
                ret |= m >> 21
                if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                    if (ret & 0x7F) < 0x7B:  # noqa: PLR2004
                        # rounding
                        ret += 1
                    elif saturate:
                        ret |= 0x7B
                    else:
                        # Rounding past the largest finite value (0x7B) gives
                        # infinity. Keep only the sign bit: OR-ing 0x7C into
                        # the 0x7B already present would produce 0x7F, a NaN.
                        ret = (ret & 0x80) | 0x7C
            elif saturate:
                ret |= 0x7B
            else:
                ret |= 0x7C
        return int(ret)
    else:
        raise NotImplementedError("fn and uz must be both False or True.")
def pack_float32_to_4bit(array: np.ndarray | Sequence, signed: bool) -> np.ndarray:
    """Convert float32 values to a 4-bit data type and pack every two
    consecutive elements into one byte.

    See :ref:`onnx-detail-int4` for technical details.

    Args:
        array: array of float to convert and pack
        signed: Whether the 4 bit variant is signed or unsigned

    Returns:
        Packed uint8 array with size ``ceil(array.size / 2)`` (single
        dimension). An odd element count is padded with a trailing 0.
    """
    values = (
        array if isinstance(array, np.ndarray) else np.asarray(array, dtype=np.float32)
    )
    flat = values.ravel()
    if np.prod(values.shape) % 2 == 1:
        # Give the last element a partner so it can be packed.
        flat = np.append(flat, np.array([0]))

    pack_pair = np.frompyfunc(
        lambda first, second: subbyte.float32x2_to_4bitx2(first, second, signed),
        2,
        1,
    )
    packed: np.ndarray = pack_pair(flat[0::2], flat[1::2])
    return packed.astype(np.uint8)
def pack_float32_to_float4e2m1(array: np.ndarray | Sequence) -> np.ndarray:
    """Convert float32 values to float4e2m1 and pack every two consecutive
    elements into one byte.

    See :ref:`onnx-detail-float4` for technical details.

    Args:
        array: array of float to convert and pack

    Returns:
        Packed array of float4e2m1 (as uint8) with size ``ceil(array.size / 2)``
        (single dimension). An odd element count is padded with a trailing 0.
    """
    values = (
        array if isinstance(array, np.ndarray) else np.asarray(array, dtype=np.float32)
    )
    flat = values.ravel()
    if np.prod(values.shape) % 2 == 1:
        flat = np.append(flat, np.array([0]))
    packed = subbyte.float32x2_to_float4e2m1x2(flat[0::2], flat[1::2])
    return packed.astype(np.uint8)
def make_tensor(
    name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = False
) -> TensorProto:
    """Make a TensorProto with specified arguments.

    If raw is False, this function will choose the corresponding proto field
    to store the values based on data_type. If raw is True, use "raw_data"
    proto field to store the values, and values should be of type bytes in
    this case.

    Args:
        name (string): tensor name
        data_type (int): a value such as onnx.TensorProto.FLOAT
        dims (List[int]): shape
        vals: values
        raw (bool): if True, vals contains the serialized content of the
            tensor, otherwise, vals should be a list of values of the type
            defined by *data_type*

    Returns:
        TensorProto

    Raises:
        TypeError: if raw is True for a STRING tensor.
        ValueError: if ``len(vals)`` does not match the size implied by *dims*
            (a byte count when raw, an element count otherwise). For
            UINT4/INT4/FLOAT4E2M1, half a byte of padding is tolerated.
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING and raw:
        raise TypeError("Can not use raw_data to store string type.")

    # NOTE(review): tensor_dtype_to_np_dtype is defined elsewhere in this
    # module; presumably maps an ONNX data type to its numpy storage dtype.
    np_dtype = tensor_dtype_to_np_dtype(data_type)

    # Check number of vals specified equals tensor size
    # (when raw, expected_size starts as the per-element byte size, so the
    # product below is a byte count).
    expected_size: float = 1
    if raw:
        # NumPy doesn't have BFLOAT16. TENSOR_TYPE_MAP maps it to float32, which has the wrong itemsize.
        if data_type == TensorProto.BFLOAT16:
            expected_size = 2
        elif data_type in (
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        ):
            expected_size = 1
        # NumPy doesn't have INT4/FP4. It is packed in couples to UINT8 buffers.
        elif data_type in (TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1):
            expected_size = 0.5
        else:
            expected_size = np_dtype.itemsize

    # Multi-dimensional arrays are flattened so len() counts every element.
    if isinstance(vals, np.ndarray) and len(vals.shape) > 1:
        vals = vals.flatten()
    for d in dims:
        expected_size *= d

    if len(vals) != expected_size:
        # padding of half a byte is acceptable for 4bit types
        if not (
            data_type in (TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1)
            and len(vals) == expected_size + 0.5
        ):
            raise ValueError(
                f"Number of values does not match tensor's size. Expected {expected_size}, but it is {len(vals)}. "
            )

    if raw:
        tensor.raw_data = vals
    else:
        # Convert vals into the representation expected by the typed field.
        if data_type in (TensorProto.COMPLEX64, TensorProto.COMPLEX128):
            vals = split_complex_to_pairs(vals)
        elif data_type == TensorProto.FLOAT16:
            # float16 values are stored as their uint16 bit patterns.
            vals = (
                np.array(vals).astype(np_dtype).view(dtype=np.uint16).flatten().tolist()
            )
        elif data_type in (
            TensorProto.BFLOAT16,
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        ):
            # Each value is encoded to its integer bit pattern by the matching
            # float32_to_* converter.
            fcast = {
                TensorProto.BFLOAT16: float32_to_bfloat16,
                TensorProto.FLOAT8E4M3FN: float32_to_float8e4m3,
                TensorProto.FLOAT8E4M3FNUZ: lambda *args: float32_to_float8e4m3(  # type: ignore[misc]
                    *args, uz=True
                ),
                TensorProto.FLOAT8E5M2: float32_to_float8e5m2,
                TensorProto.FLOAT8E5M2FNUZ: lambda *args: float32_to_float8e5m2(  # type: ignore[misc]
                    *args, fn=True, uz=True
                ),
            }[
                data_type  # type: ignore[index]
            ]
            vals = list(
                map(  # type: ignore[call-overload]
                    fcast,
                    np.array(vals).astype(np_dtype).flatten().tolist(),
                )
            )
        elif data_type in (
            TensorProto.UINT4,
            TensorProto.INT4,
        ):
            signed = data_type == TensorProto.INT4

            # Two packed 4-bit values must be represented as a single uint8 value.
            # Therefore, pack_float32_to_4bit() sets the dtype of the output vals
            # to uint8 regardless of the value of 'signed'. Using int8 would cause
            # the size of int4 tensors to increase ~5x if the tensor contains negative values (due to
            # the way negative values are serialized by protobuf).
            vals = pack_float32_to_4bit(vals, signed=signed).flatten().tolist()
        elif data_type == TensorProto.FLOAT4E2M1:
            vals = pack_float32_to_float4e2m1(vals).flatten().tolist()
        elif data_type == TensorProto.BOOL:
            vals = np.array(vals).astype(int)
        elif data_type == TensorProto.STRING:
            vals = np.array(vals).astype(bytes)
        # NOTE(review): tensor_dtype_to_field is defined elsewhere in this
        # module; presumably names the typed repeated field (e.g. float_data).
        field = tensor_dtype_to_field(data_type)
        getattr(tensor, field).extend(vals)
    tensor.dims.extend(dims)
    return tensor
def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> SparseTensorProto:
    """Construct a SparseTensorProto.

    Args:
        values (TensorProto): the values (copied)
        indices (TensorProto): the indices (copied)
        dims: the shape of the dense tensor

    Returns:
        SparseTensorProto
    """
    result = SparseTensorProto()
    result.values.CopyFrom(values)
    result.indices.CopyFrom(indices)
    result.dims.extend(dims)
    return result
def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a SequenceProto holding *values*.

    Args:
        name: name of the sequence.
        elem_type: a ``SequenceProto.DataType`` selecting the repeated field
            that receives *values*.
        values: the elements, of the proto type matching *elem_type*.

    Returns:
        SequenceProto. For ``SequenceProto.UNDEFINED``, *values* is ignored.

    Raises:
        TypeError: if *elem_type* is not a supported element type.
    """
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type

    if elem_type == SequenceProto.UNDEFINED:
        return sequence

    # Every comparison uses SequenceProto's own DataType enum.
    if elem_type == SequenceProto.TENSOR:
        attribute = sequence.tensor_values
    elif elem_type == SequenceProto.SPARSE_TENSOR:
        attribute = sequence.sparse_tensor_values
    elif elem_type == SequenceProto.SEQUENCE:
        attribute = sequence.sequence_values
    elif elem_type == SequenceProto.MAP:
        attribute = sequence.map_values
    elif elem_type == SequenceProto.OPTIONAL:
        attribute = sequence.optional_values
    else:
        raise TypeError("The element type in the input sequence is not supported.")

    attribute.extend(values)
    return sequence
def make_map(
    name: str, key_type: int, keys: list[Any], values: SequenceProto
) -> MapProto:
    """Make a MapProto from the given key-value pair arguments.

    Criteria for conversion (these are not validated here):
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type

    Args:
        name: name of the map.
        key_type: ``TensorProto.STRING`` or one of the integer TensorProto
            types INT8/16/32/64 or UINT8/16/32/64.
        keys: the keys. They are stored in ``string_keys`` for STRING and in
            ``keys`` otherwise.
        values: a SequenceProto holding the values; it is copied.

    Returns:
        MapProto

    Raises:
        TypeError: if *key_type* is neither STRING nor a supported integer
            type, since such keys could not be stored.
    """
    valid_key_int_types = (
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    )
    if key_type != TensorProto.STRING and key_type not in valid_key_int_types:
        # Fail loudly instead of silently producing a map with no keys.
        raise TypeError(f"Unsupported map key type: {key_type}.")

    map_proto = MapProto()
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    else:
        map_proto.keys.extend(keys)
    map_proto.values.CopyFrom(values)
    return map_proto
def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: google.protobuf.message.Message | None,
) -> OptionalProto:
    """Make an OptionalProto holding *value*.

    Args:
        name: name of the optional.
        elem_type: an ``OptionalProto.DataType`` selecting which field holds
            *value*.
        value: the contained message. It is ignored (and may be None) for
            ``OptionalProto.UNDEFINED`` and required otherwise; it is copied.

    Returns:
        OptionalProto

    Raises:
        TypeError: if *elem_type* is not a supported element type.
        ValueError: if *value* is None for a defined *elem_type*.
    """
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type

    if elem_type == OptionalProto.UNDEFINED:
        return optional

    if elem_type == OptionalProto.TENSOR:
        attribute = optional.tensor_value
    elif elem_type == OptionalProto.SPARSE_TENSOR:
        attribute = optional.sparse_tensor_value
    elif elem_type == OptionalProto.SEQUENCE:
        attribute = optional.sequence_value
    elif elem_type == OptionalProto.MAP:
        attribute = optional.map_value
    elif elem_type == OptionalProto.OPTIONAL:
        attribute = optional.optional_value
    else:
        raise TypeError("The element type in the input optional is not supported.")

    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if value is None:
        raise ValueError(
            f"A value is required for an optional of element type {elem_type}."
        )
    attribute.CopyFrom(value)  # type: ignore[arg-type]
    return optional
def _to_bytes(value: str | bytes) -> bytes:
    """Return *value* as bytes, UTF-8 encoding it first when it is a str."""
    return value.encode("utf-8") if isinstance(value, str) else value
def make_attribute(
    key: str,
    value: Any,
    doc_string: str | None = None,
    attr_type: int | None = None,
) -> AttributeProto:
    """Makes an AttributeProto based on the value type.

    Args:
        key: attribute name.
        value: an int, float, str/bytes, TensorProto, SparseTensorProto,
            GraphProto or TypeProto, or an iterable of one of those kinds.
        doc_string: optional documentation; not written when empty.
        attr_type: optional expected ``AttributeProto`` type. It is required
            for an empty iterable and must agree with the inferred type.

    Returns:
        AttributeProto

    Raises:
        ValueError: if the type of an iterable value cannot be inferred.
        TypeError: if *value* is not an accepted attribute value, or if it
            does not match *attr_type*.
    """
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    # Singular cases
    if isinstance(value, numbers.Integral):
        attr.i = int(value)
        attr.type = AttributeProto.INT
    elif isinstance(value, numbers.Real):
        attr.f = float(value)
        attr.type = AttributeProto.FLOAT
    elif isinstance(value, (str, bytes)):
        # Encode strings into utf-8
        attr.s = _to_bytes(value)
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # Iterable cases
    elif isinstance(value, collections.abc.Iterable):
        value = list(value)
        if len(value) == 0 and attr_type is None:
            raise ValueError(
                f"Could not infer attribute `{key}` type from empty iterator"
            )
        if attr_type is None:
            # Pick the first kind that covers every element's type; Integral
            # is tried before Real so that all-int lists become INTS.
            types = {type(v) for v in value}
            for exp_t, exp_enum in (
                (numbers.Integral, AttributeProto.INTS),
                (numbers.Real, AttributeProto.FLOATS),
                ((str, bytes), AttributeProto.STRINGS),
                (TensorProto, AttributeProto.TENSORS),
                (SparseTensorProto, AttributeProto.SPARSE_TENSORS),
                (GraphProto, AttributeProto.GRAPHS),
                (TypeProto, AttributeProto.TYPE_PROTOS),
            ):
                if all(issubclass(t, exp_t) for t in types):
                    attr_type = exp_enum
                    break
            if attr_type is None:
                raise ValueError(
                    "Could not infer the attribute type from the elements of the passed Iterable value."
                )

        if attr_type == AttributeProto.INTS:
            attr.ints.extend(value)
            attr.type = AttributeProto.INTS
        elif attr_type == AttributeProto.FLOATS:
            attr.floats.extend(value)
            attr.type = AttributeProto.FLOATS
        elif attr_type == AttributeProto.STRINGS:
            attr.strings.extend(_to_bytes(v) for v in value)
            attr.type = AttributeProto.STRINGS
        elif attr_type == AttributeProto.TENSORS:
            attr.tensors.extend(value)
            attr.type = AttributeProto.TENSORS
        elif attr_type == AttributeProto.SPARSE_TENSORS:
            attr.sparse_tensors.extend(value)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif attr_type == AttributeProto.GRAPHS:
            attr.graphs.extend(value)
            attr.type = AttributeProto.GRAPHS
        elif attr_type == AttributeProto.TYPE_PROTOS:
            attr.type_protos.extend(value)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            # Inference always yields one of the repeated types above, so this
            # is reached only when the caller passed a non-repeated attr_type
            # together with an iterable value.
            raise TypeError(
                f"Attribute `{key}`: specified type {attr_type} cannot hold an iterable value."
            )
    else:
        raise TypeError(f"'{value}' is not an accepted attribute value.")

    if attr_type is not None and attr.type != attr_type:
        raise TypeError(
            f"Inferred attribute type '{_attr_type_to_str(attr.type)}'({attr.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
        )
    return attr
def make_attribute_ref(
    name: str, attr_type: AttributeProto.AttributeType, doc_string: str | None = None
) -> AttributeProto:
    """Make an AttributeProto holding a reference to the parent function's
    attribute of given name and type.

    Only ``name``, ``type`` and (when non-empty) ``doc_string`` are set; no
    value is stored.

    NOTE(review): ``ref_attr_name`` is not set here despite the "reference"
    wording — confirm this matches the intended use.
    """
    attr = AttributeProto()
    attr.name, attr.type = name, attr_type
    if doc_string:
        attr.doc_string = doc_string
    return attr
def get_attribute_value(attr: AttributeProto) -> Any:
    """Return the Python value stored in *attr*.

    Singular attributes return the field itself (the message object for
    tensor/graph/type fields). Repeated attributes return a new ``list``.
    ``UNDEFINED`` returns None.

    Raises:
        ValueError: if *attr* is a reference attribute (``ref_attr_name`` is
            set) or its type is not handled here.
    """
    if attr.ref_attr_name:
        raise ValueError(f"Cannot get value of reference attribute: {attr}")

    singular_fields = {
        AttributeProto.FLOAT: "f",
        AttributeProto.INT: "i",
        AttributeProto.STRING: "s",
        AttributeProto.TENSOR: "t",
        AttributeProto.SPARSE_TENSOR: "sparse_tensor",
        AttributeProto.GRAPH: "g",
        AttributeProto.TYPE_PROTO: "tp",
    }
    repeated_fields = {
        AttributeProto.FLOATS: "floats",
        AttributeProto.INTS: "ints",
        AttributeProto.STRINGS: "strings",
        AttributeProto.TENSORS: "tensors",
        AttributeProto.SPARSE_TENSORS: "sparse_tensors",
        AttributeProto.GRAPHS: "graphs",
        AttributeProto.TYPE_PROTOS: "type_protos",
    }

    kind = attr.type
    if kind in singular_fields:
        return getattr(attr, singular_fields[kind])
    if kind in repeated_fields:
        return list(getattr(attr, repeated_fields[kind]))
    if kind == AttributeProto.UNDEFINED:
        return None
    raise ValueError(f"Unsupported ONNX attribute: {attr}")
def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    """Return the value of *node*'s single attribute named *attr_name*.

    Raises:
        ValueError: if *node* has no attribute with that name, or more than one.
    """
    matches = [attr for attr in node.attribute if attr.name == attr_name]
    if not matches:
        raise ValueError(f"Node has no attribute with name {attr_name}")
    if len(matches) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    (only_match,) = matches
    return get_attribute_value(only_match)
def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    """Return a ValueInfoProto carrying only *name*, with no type set."""
    info = ValueInfoProto()
    info.name = name
    return info
def make_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape.

    Args:
        elem_type: element data type.
        shape: one entry per dimension: an int sets ``dim_value``, a str sets
            ``dim_param``, None leaves the dimension unset. Passing None for
            the whole shape leaves the shape field untouched.
        shape_denotation: optional per-dimension denotations; when given it
            must have the same length as *shape*.

    Raises:
        ValueError: on a denotation length mismatch or an invalid shape entry.
    """
    type_proto = TypeProto()
    tensor_type = type_proto.tensor_type
    tensor_type.elem_type = elem_type
    shape_proto = tensor_type.shape
    if shape is None:
        return type_proto

    # Protobuf omits a field that is never touched, whereas an explicitly
    # extended (even empty) list is emitted. Consumers can see the
    # difference, so always emit the shape, even when it is empty.
    shape_proto.dim.extend([])

    if shape_denotation and len(shape_denotation) != len(shape):
        raise ValueError(
            "Invalid shape_denotation. Must be of the same length as shape."
        )

    for index, size in enumerate(shape):
        dim = shape_proto.dim.add()
        if isinstance(size, int):
            dim.dim_value = size
        elif isinstance(size, str):
            dim.dim_param = size
        elif size is not None:
            raise ValueError(
                f"Invalid item in shape: {size}. Needs to be of int or str."
            )
        if shape_denotation:
            dim.denotation = shape_denotation[index]
    return type_proto
def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto with a tensor type built from the data type and shape.

    See ``make_tensor_type_proto`` for how *shape* and *shape_denotation* are
    interpreted. *doc_string* is written only when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    info.type.CopyFrom(make_tensor_type_proto(elem_type, shape, shape_denotation))
    return info
def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape.

    Args:
        elem_type: element data type.
        shape: one entry per dimension: an int sets ``dim_value``, a str sets
            ``dim_param``, None leaves the dimension unset. Passing None for
            the whole shape leaves the shape field untouched.
        shape_denotation: optional per-dimension denotations; when given it
            must have the same length as *shape*.

    Raises:
        ValueError: on a denotation length mismatch or an invalid shape entry.
    """
    result = TypeProto()
    sparse_type = result.sparse_tensor_type
    sparse_type.elem_type = elem_type
    dims_proto = sparse_type.shape

    if shape is not None:
        # Extend explicitly so that an empty shape is still emitted:
        # protobuf omits fields that are never touched.
        dims_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for position, entry in enumerate(shape):
            new_dim = dims_proto.dim.add()
            if isinstance(entry, int):
                new_dim.dim_value = entry
            elif isinstance(entry, str):
                new_dim.dim_param = entry
            elif entry is not None:
                raise ValueError(
                    f"Invalid item in shape: {entry}. Needs to be of int or text."
                )
            if shape_denotation:
                new_dim.denotation = shape_denotation[position]

    return result
def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto with a sparse-tensor type built from the data type and shape.

    *doc_string* is written only when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    sparse_type = make_sparse_tensor_type_proto(
        elem_type, shape, shape_denotation
    ).sparse_tensor_type
    info.type.sparse_tensor_type.CopyFrom(sparse_type)
    return info
def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto whose element type is a copy of *inner_type_proto*."""
    result = TypeProto()
    elem = result.sequence_type.elem_type
    elem.CopyFrom(inner_type_proto)
    return result
def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto whose element type is a copy of *inner_type_proto*."""
    result = TypeProto()
    elem = result.optional_type.elem_type
    elem.CopyFrom(inner_type_proto)
    return result
def make_map_type_proto(
    key_type: int,
    value_type: TypeProto,
) -> TypeProto:
    """Makes a map TypeProto with *key_type* keys and a copy of *value_type* as values."""
    result = TypeProto()
    map_type = result.map_type
    map_type.key_type = key_type
    map_type.value_type.CopyFrom(value_type)
    return result
def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto whose type is a copy of *type_proto*.

    *doc_string* is written only when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    info.type.CopyFrom(type_proto)
    return info
def _sanitize_str(s: str | bytes) -> str: if isinstance(s, str): sanitized = s elif isinstance(s, bytes): sanitized = s.decode("utf-8", errors="ignore") else: sanitized = str(s) if len(sanitized) < 64: # noqa: PLR2004 return sanitized return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>"
def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    elem_shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a Sequence[Tensors] ValueInfoProto based on the data type and shape.

    Args:
        name: value name stored on the ValueInfoProto.
        elem_type: element data type of each tensor in the sequence.
        shape: shape of each tensor in the sequence, forwarded to
            ``make_tensor_type_proto``.
        doc_string: optional documentation; only written when non-empty.
        elem_shape_denotation: per-dimension denotations for the element
            tensor type, forwarded to ``make_tensor_type_proto``.

    Returns:
        A new ValueInfoProto whose type is a sequence of tensors.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    element_type = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    wrapped = make_sequence_type_proto(element_type)
    info.type.sequence_type.CopyFrom(wrapped.sequence_type)
    return info
def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render an AttributeProto as ``name = value`` text.

    The populated fields are checked in a fixed order and only the first
    match is rendered. Tensors, sparse tensors and graphs are summarised
    rather than printed in full.

    Args:
        attr: the attribute to render.
        subgraphs: when True, also return the graphs held by the attribute.

    Returns:
        The rendered text, or ``(text, graphs)`` when ``subgraphs`` is True.
        ``graphs`` holds ``attr.g`` or the contents of ``attr.graphs`` and is
        empty for non-graph attributes.
    """
    content = []
    content.append(attr.name)
    content.append("=")

    def str_float(f: float) -> str:
        # NB: Different Python versions print different numbers of trailing
        # decimals, specifying this explicitly keeps it consistent for all
        # versions
        return f"{f:.15g}"

    def str_int(i: int) -> str:
        return str(i)

    _T = TypeVar("_T")

    def str_list(str_elem: Callable[[_T], str], xs: Sequence[_T]) -> str:
        return "[" + ", ".join(map(str_elem, xs)) + "]"

    # Dispatch relies on HasField() to detect which singular field is set.
    # That works with ONNX's proto2 schema. Without proto2-style field
    # presence this would need to switch to attr.type.
    # Graph-valued attributes are only named here. The graphs themselves are
    # passed back to the caller so it can print them separately.
    graphs = []
    if attr.HasField("f"):
        content.append(str_float(attr.f))
    elif attr.HasField("i"):
        content.append(str_int(attr.i))
    elif attr.HasField("s"):
        # _sanitize_str decodes bytes and truncates long values before repr().
        content.append(repr(_sanitize_str(attr.s)))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            content.append("<Tensor>")
        else:
            # special case to print scalars
            field = tensor_dtype_to_field(attr.t.data_type)
            content.append(f"<Scalar Tensor {getattr(attr.t, field)}>")
    elif attr.HasField("g"):
        content.append(f"<graph {attr.g.name}>")
        graphs.append(attr.g)
    elif attr.HasField("tp"):
        content.append(f"<Type Proto {attr.tp}>")
    elif attr.HasField("sparse_tensor"):
        # Sparse-tensor attributes used to fall through to "<Unknown>".
        content.append("<Sparse Tensor>")
    elif attr.floats:
        content.append(str_list(str_float, attr.floats))
    elif attr.ints:
        content.append(str_list(str_int, attr.ints))
    elif attr.strings:
        # Each element goes through _sanitize_str before the list is printed.
        content.append(str(list(map(_sanitize_str, attr.strings))))
    elif attr.tensors:
        content.append("[<Tensor>, ...]")
    elif attr.type_protos:
        content.append("[")
        for i, tp in enumerate(attr.type_protos):
            comma = "," if i != len(attr.type_protos) - 1 else ""
            content.append(f"<Type Proto {tp}>{comma}")
        content.append("]")
    elif attr.graphs:
        content.append("[")
        for i, g in enumerate(attr.graphs):
            comma = "," if i != len(attr.graphs) - 1 else ""
            content.append(f"<graph {g.name}>{comma}")
        content.append("]")
        graphs.extend(attr.graphs)
    elif attr.sparse_tensors:
        content.append("[<Sparse Tensor>, ...]")
    else:
        content.append("<Unknown>")
    if subgraphs:
        return " ".join(content), graphs
    return " ".join(content)
def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    """Render one shape dimension as text.

    Returns the string form of whichever member of the ``value`` oneof is
    set (``dim_value`` or ``dim_param``), or ``"?"`` when neither is set.
    """
    selected = dim.WhichOneof("value")
    return "?" if selected is None else str(getattr(dim, selected))
def printable_type(t: TypeProto) -> str:
    """Render a TypeProto as short text.

    Only tensor types are fully rendered, as ``ELEMTYPE`` optionally followed
    by ``, AxBx...`` (one entry per dimension) or ``, scalar`` for an
    explicitly empty shape. An unset type renders as ``""``. Any other type
    kind renders as ``"Unknown type <kind>"``.
    """
    kind = t.WhichOneof("value")
    if kind is None:
        return ""
    if kind != "tensor_type":
        return f"Unknown type {kind}"

    tensor = t.tensor_type
    text = TensorProto.DataType.Name(tensor.elem_type)
    if not tensor.HasField("shape"):
        return text
    dims = tensor.shape.dim
    if len(dims):
        return text + ", " + "x".join(printable_dim(d) for d in dims)
    return text + ", scalar"
def printable_value_info(v: ValueInfoProto) -> str:
    """Render a ValueInfoProto as ``%name`` or ``%name[<type>]``."""
    label = f"%{v.name}"
    # NOTE(review): protobuf messages are generally truthy even when the field
    # is unset, so this branch likely always runs (giving "%name[]" for
    # untyped values). Confirm before relying on it.
    if not v.type:
        return label
    return f"{label}[{printable_type(v.type)}]"
def printable_tensor_proto(t: TensorProto) -> str:
    """Render a TensorProto header as ``%name[ELEMTYPE, AxB...]``.

    Only the name, data type and dimensions are printed, never the data.
    A tensor with no dims is shown as ``, scalar``.
    """
    pieces = [f"%{t.name}[", TensorProto.DataType.Name(t.data_type)]
    if t.dims is not None:
        pieces.append(
            ", " + "x".join(str(d) for d in t.dims) if len(t.dims) else ", scalar"
        )
    pieces.append("]")
    return "".join(pieces)
def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render a NodeProto as a single line of text.

    The line has the form ``%out1, %out2 = OpType[attrs](%in1, %in2)``.
    The ``%outputs =`` part is omitted when the node has no outputs, and the
    bracketed attribute list is omitted when it has no attributes. Attributes
    are sorted by their rendered text.

    Args:
        node: the node to render.
        prefix: text placed before the line, typically indentation.
        subgraphs: when True, also collect graphs held in attributes.

    Returns:
        The rendered line, or ``(line, graphs)`` when ``subgraphs`` is True.

    Raises:
        TypeError: if ``printable_attribute`` returns an unexpected shape.
    """
    nested: list[GraphProto] = []
    rendered_attrs: list[str] = []
    for attribute in node.attribute:
        if not subgraphs:
            text = printable_attribute(attribute)
            if not isinstance(text, str):
                raise TypeError(f"printed must be an instance of {str}.")
            rendered_attrs.append(text)
            continue
        pair = printable_attribute(attribute, subgraphs)
        if not isinstance(pair[1], list):
            raise TypeError(
                f"printed_attr_subgraphs[1] must be an instance of {list}."
            )
        nested.extend(pair[1])
        rendered_attrs.append(pair[0])

    inputs_text = ", ".join(f"%{name}" for name in node.input)
    if node.attribute:
        attrs_text = ", ".join(sorted(rendered_attrs))
        call = f"{node.op_type}[{attrs_text}]({inputs_text})"
    else:
        call = f"{node.op_type}({inputs_text})"

    parts = []
    if len(node.output):
        parts.extend([", ".join(f"%{name}" for name in node.output), "="])
    parts.append(call)
    line = prefix + " ".join(parts)
    return (line, nested) if subgraphs else line
def printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """Display a GraphProto as a string.

    The header line(s) list, in order:

    * required inputs,
    * inputs that have a matching initializer (optional inputs with a
      default),
    * initializers that are not graph inputs.

    Then come one line per node, a ``return`` line and the closing brace.
    Graphs found in node attributes are printed afterwards, each preceded by
    a blank line and rendered without ``prefix``.

    Args:
        graph (GraphProto): the graph to display
        prefix (string): prefix of every line

    Returns:
        string
    """
    content: list[str] = []
    indent = prefix + " "
    # Words of the header line currently being assembled.
    header = ["graph", graph.name]

    def flush_section(lines: list[str]) -> None:
        # Emit the pending header words as one line, then each entry on its
        # own line. The section's closing ")" starts the next header line.
        nonlocal header
        content.append(prefix + " ".join(header))
        content.extend(prefix + " " + line for line in lines)
        header = [")"]

    initializers = {t.name for t in graph.initializer}
    graph_inputs = {i.name for i in graph.input}
    if len(graph.input):
        header.append("(")
        in_strs = []  # required inputs
        # optional inputs with initializer providing default value
        in_with_init_strs = []
        for inp in graph.input:
            if inp.name not in initializers:
                in_strs.append(printable_value_info(inp))
            else:
                in_with_init_strs.append(printable_value_info(inp))
        if in_strs:
            flush_section(in_strs)
        else:
            header.append(")")
        if in_with_init_strs:
            header.append("optional inputs with matching initializers (")
            flush_section(in_with_init_strs)

    # From IR 4 onwards an initializer is not required to have a matching
    # graph input, so output the name, type and shape of those as well. This
    # also applies when the graph declares no inputs at all; such
    # initializers were previously not printed.
    init_strs = [
        printable_tensor_proto(t)
        for t in graph.initializer
        if t.name not in graph_inputs
    ]
    if init_strs:
        header.append("initializers (")
        flush_section(init_strs)

    header.append("{")
    content.append(prefix + " ".join(header))

    # body: one line per node; collect attribute subgraphs for later printing
    graphs: list[GraphProto] = []
    for node in graph.node:
        contents_subgraphs = printable_node(node, indent, subgraphs=True)
        if not isinstance(contents_subgraphs[1], list):
            raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
        content.append(contents_subgraphs[0])
        graphs.extend(contents_subgraphs[1])

    # tail
    tail = ["return"]
    if len(graph.output):
        tail.append(", ".join([f"%{out.name}" for out in graph.output]))
    content.append(indent + " ".join(tail))
    # closing bracket
    content.append(prefix + "}")
    content.extend("\n" + printable_graph(g) for g in graphs)
    return "\n".join(content)
def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Empties `doc_string` field on any nested protobuf messages

    Walks every field declared on ``proto``'s descriptor. Any field named
    ``doc_string`` is cleared. Message-typed fields are recursed into: every
    element of repeated ones, and singular ones only when set. The message
    is modified in place.

    Raises:
        TypeError: if ``proto`` is not a protobuf Message.
    """
    if not isinstance(proto, google.protobuf.message.Message):
        raise TypeError(
            f"proto must be an instance of {google.protobuf.message.Message}."
        )
    for field in proto.DESCRIPTOR.fields:
        field_name = field.name
        if field_name == "doc_string":
            proto.ClearField(field_name)
            continue
        if field.type != field.TYPE_MESSAGE:
            continue
        if field.label == field.LABEL_REPEATED:
            for child in getattr(proto, field_name):
                strip_doc_string(child)
        elif proto.HasField(field_name):
            strip_doc_string(getattr(proto, field_name))
def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: GraphProto | None,
    initialization_bindings: AssignmentBindingType | None,
) -> TrainingInfoProto:
    """Build a TrainingInfoProto from graphs and (key, value) bindings.

    Args:
        algorithm: graph copied into ``algorithm``.
        algorithm_bindings: (key, value) pairs appended to ``update_binding``.
        initialization: optional graph copied into ``initialization`` when
            truthy.
        initialization_bindings: optional (key, value) pairs appended to
            ``initialization_binding``; skipped when None or empty.

    Returns:
        A new TrainingInfoProto.
    """
    info = TrainingInfoProto()
    info.algorithm.CopyFrom(algorithm)
    for key, value in algorithm_bindings:
        entry = info.update_binding.add()
        entry.key, entry.value = key, value
    if initialization:
        info.initialization.CopyFrom(initialization)
    for key, value in initialization_bindings or ():
        entry = info.initialization_binding.add()
        entry.key, entry.value = key, value
    return info
# The following functions map between TensorProto data types, numpy dtypes and storage fields.
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """Return the numpy dtype matching a TensorProto data type.

    Useful when building tensors from numpy data.

    Args:
        tensor_dtype: a ``TensorProto`` data type value.

    Returns:
        The ``np_dtype`` recorded for it in ``mapping.TENSOR_TYPE_MAP``.
        Unknown values fail at that lookup.
    """
    entry = mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.np_dtype
def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
    """Return the TensorProto data type used to store values of ``tensor_dtype``.

    Args:
        tensor_dtype: a ``TensorProto`` data type value.

    Returns:
        The ``storage_dtype`` recorded for it in ``mapping.TENSOR_TYPE_MAP``.
    """
    entry = mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.storage_dtype
def tensor_dtype_to_string(tensor_dtype: int) -> str:
    """Return the name of a TensorProto data type.

    Args:
        tensor_dtype: a ``TensorProto`` data type value.

    Returns:
        The ``name`` recorded for it in ``mapping.TENSOR_TYPE_MAP``.
    """
    entry = mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.name
def tensor_dtype_to_field(tensor_dtype: int) -> str:
    """Return the TensorProto field name that holds data of ``tensor_dtype``.

    The data type is first mapped to its storage data type, which is then
    looked up in ``mapping._STORAGE_TENSOR_TYPE_TO_FIELD``. Useful when
    making tensors.

    Args:
        tensor_dtype: a ``TensorProto`` data type value.

    Returns:
        The storage field name.
    """
    storage_dtype = mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
    return mapping._STORAGE_TENSOR_TYPE_TO_FIELD[storage_dtype]
def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> int:
    """Convert a numpy dtype to the corresponding TensorProto data type.

    Lookup order:

    1. ``mapping._NP_TYPE_TO_TENSOR_TYPE``.
    2. Any numpy string dtype maps to ``TensorProto.STRING``.
    3. ONNX's custom numpy types (bfloat16, the float8 variants, int4/uint4,
       float4e2m1) are resolved by their descriptor name.

    Args:
        np_dtype: numpy's data_type

    Returns:
        TensorProto's data_type

    Raises:
        ValueError: if no lookup step recognises ``np_dtype``.
    """
    direct_map = mapping._NP_TYPE_TO_TENSOR_TYPE
    if np_dtype in direct_map:
        return cast(int, direct_map[np_dtype])
    if np.issubdtype(np_dtype, np.str_):
        return TensorProto.STRING

    custom_dtypes = {
        custom_np_types.bfloat16,
        custom_np_types.float8e4m3fn,
        custom_np_types.float8e4m3fnuz,
        custom_np_types.float8e5m2,
        custom_np_types.float8e5m2fnuz,
        custom_np_types.int4,
        custom_np_types.uint4,
        custom_np_types.float4e2m1,
    }
    if np_dtype not in custom_dtypes:
        raise ValueError(
            f"Unable to convert type {np_dtype!r} into TensorProto element type."
        )
    # Custom types are keyed by the field name in their dtype descriptor.
    descriptor_name = np_dtype.descr[0][0]
    return custom_np_types.mapping_name_to_data_type[descriptor_name]
def get_all_tensor_dtypes() -> KeysView[int]:
    """Return every TensorProto data type known to ``mapping``.

    Returns:
        A keys view over ``mapping.TENSOR_TYPE_MAP``.
    """
    type_map = mapping.TENSOR_TYPE_MAP
    return type_map.keys()
# Reverse lookup of AttributeProto.AttributeType: enum value -> enum name.
_ATTRIBUTE_TYPE_TO_STR: dict[int, str] = {
    value: name for name, value in AttributeProto.AttributeType.items()
}


def _attr_type_to_str(attr_type: int) -> str:
    """Convert AttributeProto type to string.

    Args:
        attr_type: AttributeProto type.

    Returns:
        The enum name for ``attr_type``. For values not defined in
        ``AttributeProto.AttributeType``, returns the enum's first declared
        name (presumably ``UNDEFINED`` — confirm against onnx.proto).
    """
    if attr_type not in AttributeProto.AttributeType.values():
        return AttributeProto.AttributeType.keys()[0]
    return _ATTRIBUTE_TYPE_TO_STR[attr_type]