# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
__all__ = [
    # Module-level constants: ML build flag, IR version markers, op status.
    "ONNX_ML",
    "IR_VERSION",
    "IR_VERSION_2017_10_10",
    "IR_VERSION_2017_10_30",
    "IR_VERSION_2017_11_3",
    "IR_VERSION_2019_1_22",
    "IR_VERSION_2019_3_18",
    "IR_VERSION_2019_9_19",
    "IR_VERSION_2020_5_8",
    "IR_VERSION_2021_7_30",
    "IR_VERSION_2023_5_5",
    "IR_VERSION_2024_3_25",
    "EXPERIMENTAL",
    "STABLE",
    # Subpackages re-exported so they are reachable as ``onnx.<name>``.
    "checker",
    "compose",
    "defs",
    "gen_proto",
    "helper",
    "hub",
    "mapping",
    "numpy_helper",
    "parser",
    "printer",
    "shape_inference",
    "utils",
    "version_converter",
    # Protobuf message / enum classes.
    "AttributeProto",
    "FunctionProto",
    "GraphProto",
    "MapProto",
    "ModelProto",
    "NodeProto",
    "OperatorProto",
    "OperatorSetIdProto",
    "OperatorSetProto",
    "OperatorStatus",
    "OptionalProto",
    "SequenceProto",
    "SparseTensorProto",
    "StringStringEntryProto",
    "TensorAnnotation",
    "TensorProto",
    "TensorShapeProto",
    "TrainingInfoProto",
    "TypeProto",
    "ValueInfoProto",
    "Version",
    # Load/save helpers and external-data utilities.
    "convert_model_to_external_data",
    "load_external_data_for_model",
    "load_model_from_string",
    "load_model",
    "load_tensor_from_string",
    "load_tensor",
    "save_model",
    "save_tensor",
    "write_external_data_tensors",
]
# isort:skip_file
import os
import typing
from typing import IO, Literal, Union
from onnx import serialization
from onnx.onnx_cpp2py_export import ONNX_ML
from onnx.external_data_helper import (
load_external_data_for_model,
write_external_data_tensors,
convert_model_to_external_data,
)
from onnx.onnx_pb import (
AttributeProto,
EXPERIMENTAL,
FunctionProto,
GraphProto,
IR_VERSION,
IR_VERSION_2017_10_10,
IR_VERSION_2017_10_30,
IR_VERSION_2017_11_3,
IR_VERSION_2019_1_22,
IR_VERSION_2019_3_18,
IR_VERSION_2019_9_19,
IR_VERSION_2020_5_8,
IR_VERSION_2021_7_30,
IR_VERSION_2023_5_5,
IR_VERSION_2024_3_25,
ModelProto,
NodeProto,
OperatorSetIdProto,
OperatorStatus,
STABLE,
SparseTensorProto,
StringStringEntryProto,
TensorAnnotation,
TensorProto,
TensorShapeProto,
TrainingInfoProto,
TypeProto,
ValueInfoProto,
Version,
)
from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto
from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto
from onnx.version import version as __version__
# Import common subpackages so they're available when you 'import onnx'
from onnx import (
checker,
compose,
defs,
gen_proto,
helper,
hub,
mapping,
numpy_helper,
parser,
printer,
shape_inference,
utils,
version_converter,
)
# Serialization formats accepted by the load_*/save_* helpers. The Literal
# members are the formats with built-in support; plain ``str`` is accepted as
# well because users may register additional formats of their own.
_SupportedFormat = Union[
    Literal["protobuf", "textproto", "onnxtxt", "json"],  # noqa: PYI051
    str,
]

# Default serialization format, used when no other format is given or inferred.
_DEFAULT_FORMAT = "protobuf"
def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes:
    """Return the entire contents of ``f``.

    ``f`` is treated as a readable stream if it has a callable ``read``
    attribute; otherwise it is treated as a path and opened in binary mode.
    """
    reader = getattr(f, "read", None)
    if callable(reader):
        return reader()
    path = typing.cast(Union[str, os.PathLike], f)
    with open(path, "rb") as stream:
        return stream.read()
def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None:
    """Write ``content`` to ``f``.

    ``f`` is treated as a writable stream if it has a callable ``write``
    attribute; otherwise it is treated as a path, which is opened in binary
    mode and overwritten.
    """
    writer = getattr(f, "write", None)
    if callable(writer):
        writer(content)
        return
    with open(typing.cast(Union[str, os.PathLike], f), "wb") as sink:
        sink.write(content)
def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None:
    """Return the absolute path ``f`` refers to, or None if it has none.

    Args:
        f: A path (str or PathLike), a file object, or None.

    Returns:
        The absolute path as ``str`` when ``f`` is a path, or when it is a file
        object whose ``name`` attribute is a path. None otherwise.
    """
    if isinstance(f, (str, os.PathLike)):
        return os.path.abspath(os.fsdecode(f))
    # File objects opened from a path expose it via ``name``. Objects opened
    # from an integer file descriptor expose an int there, and in-memory
    # buffers have no ``name`` at all. Neither identifies a location on disk,
    # so they must not be fed to os.path.abspath (which raises on an int).
    name = getattr(f, "name", None)
    if isinstance(name, (str, bytes, os.PathLike)):
        # fsdecode keeps the return type ``str`` even for bytes names.
        return os.path.abspath(os.fsdecode(name))
    return None
def _get_serializer(
    fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None
) -> serialization.ProtoSerializer:
    """Look up the serializer to use for ``fmt`` and/or ``f`` in the registry.

    An explicit ``fmt`` always wins. Otherwise the format is inferred from the
    file extension of the path ``f`` resolves to, if any. When nothing can be
    inferred, ``_DEFAULT_FORMAT`` is used.
    """
    registry = serialization.registry
    if fmt is None:
        path = _get_file_path(f)
        inferred = (
            registry.get_format_from_file_extension(os.path.splitext(path)[1])
            if path is not None
            else None
        )
        fmt = inferred or _DEFAULT_FORMAT
    return registry.get(fmt)
def load_model(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    load_external_data: bool = True,
) -> ModelProto:
    """Loads a serialized ModelProto into memory.

    Args:
        f: A file-like object (anything with a callable ``read``) or a
            string/PathLike naming the file to read.
        format: The serialization format. When omitted, it is inferred from
            the file extension of the path ``f`` resolves to. If no format can
            be inferred, 'protobuf' is used. The encoding is assumed to be
            "utf-8" when the format is a text format.
        load_external_data: Whether to load external tensor data. This only
            happens when a file path can be determined from ``f``; the data is
            then looked up relative to the directory containing the model.
            Otherwise call :func:`load_external_data_for_model` yourself with
            the directory that holds the external data.

    Returns:
        Loaded in-memory ModelProto.
    """
    serializer = _get_serializer(format, f)
    model = serializer.deserialize_proto(_load_bytes(f), ModelProto())
    if not load_external_data:
        return model
    model_filepath = _get_file_path(f)
    if model_filepath:
        load_external_data_for_model(model, os.path.dirname(model_filepath))
    return model
def load_tensor(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> TensorProto:
    """Loads a serialized TensorProto into memory.

    Args:
        f: A file-like object (anything with a callable ``read``) or a
            string/PathLike naming the file to read.
        format: The serialization format. When omitted, it is inferred from
            the file extension of the path ``f`` resolves to. If no format can
            be inferred, 'protobuf' is used. The encoding is assumed to be
            "utf-8" when the format is a text format.

    Returns:
        Loaded in-memory TensorProto.
    """
    serializer = _get_serializer(format, f)
    return serializer.deserialize_proto(_load_bytes(f), TensorProto())
def load_model_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> ModelProto:
    """Loads a ModelProto from its serialized form held in memory.

    Args:
        s: The serialized ModelProto.
        format: The serialization format of ``s``. There is no file name to
            infer a format from, so this defaults to 'protobuf'. The encoding
            is assumed to be "utf-8" when the format is a text format.

    Returns:
        Loaded in-memory ModelProto.
    """
    return _get_serializer(format).deserialize_proto(s, ModelProto())
def load_tensor_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> TensorProto:
    """Loads a TensorProto from its serialized form held in memory.

    Args:
        s: The serialized TensorProto. ``str`` is accepted for parity with
            :func:`load_model_from_string` and is handed to the serializer
            unchanged.
        format: The serialization format of ``s``. There is no file name to
            infer a format from, so this defaults to 'protobuf'. The encoding
            is assumed to be "utf-8" when the format is a text format.

    Returns:
        Loaded in-memory TensorProto.
    """
    return _get_serializer(format).deserialize_proto(s, TensorProto())
def save_model(
    proto: ModelProto | bytes,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Saves a ModelProto to ``f``, optionally moving tensor data to external files first.

    Args:
        proto: The model to save. Either an in-memory ModelProto or its
            serialized bytes; bytes are first parsed with the default
            'protobuf' format.
        f: A file-like object (anything with a callable ``write``) or a
            string/PathLike naming the destination file.
        format: The serialization format. When omitted, it is inferred from
            the file extension of the path ``f`` resolves to. If no format can
            be inferred, 'protobuf' is used. The encoding is assumed to be
            "utf-8" when the format is a text format.
        save_as_external_data: If true, tensors are converted to external data
            via :func:`convert_model_to_external_data` before saving. Its
            return value is not used, so a passed-in ModelProto is presumably
            modified in place.
        all_tensors_to_one_file: Effective only if save_as_external_data is
            true. If true, all tensors are saved to the single file named by
            ``location``. If false, each tensor is saved to a file named after
            the tensor.
        location: Effective only if save_as_external_data is true. The
            external file all tensors are saved to, relative to the model path.
            When omitted, the name is chosen by
            :func:`convert_model_to_external_data`.
        size_threshold: Effective only if save_as_external_data is true. Only
            tensors whose data size is >= size_threshold are converted to
            external data. Use 0 to convert every tensor with raw data.
        convert_attribute: Effective only if save_as_external_data is true.
            If true, convert all tensors (attribute tensors included) to
            external data. If false, convert only non-attribute tensors.
    """
    if isinstance(proto, bytes):
        # Raw bytes are always parsed as the default format, independent of
        # the ``format`` argument (which only governs the output).
        default_serializer = _get_serializer(_DEFAULT_FORMAT)
        proto = default_serializer.deserialize_proto(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(
            proto,
            all_tensors_to_one_file,
            location,
            size_threshold,
            convert_attribute,
        )

    model_filepath = _get_file_path(f)
    if model_filepath is not None:
        # External tensors are handed to write_external_data_tensors together
        # with the directory that contains the model file.
        proto = write_external_data_tensors(proto, os.path.dirname(model_filepath))
    # NOTE(review): when ``f`` has no resolvable path (e.g. an in-memory
    # buffer), the step above is skipped, so with save_as_external_data=True
    # the external files are presumably never written -- confirm.

    serializer = _get_serializer(format, model_filepath)
    _save_bytes(serializer.serialize_proto(proto), f)
def save_tensor(
    proto: TensorProto,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> None:
    """Saves a TensorProto to ``f``.

    Args:
        proto: The in-memory TensorProto to save.
        f: A file-like object (anything with a callable ``write``) or a
            string/PathLike naming the destination file.
        format: The serialization format. When omitted, it is inferred from
            the file extension of the path ``f`` resolves to. If no format can
            be inferred, 'protobuf' is used. The encoding is assumed to be
            "utf-8" when the format is a text format.
    """
    serializer = _get_serializer(format, f)
    _save_bytes(serializer.serialize_proto(proto), f)
# Legacy short names, kept as aliases so existing callers continue to work.
load, load_from_string, save = load_model, load_model_from_string, save_model