# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections import namedtuple
from collections.abc import Sequence
from typing import Any, NewType
import numpy
import onnx.checker
import onnx.onnx_cpp2py_export.checker as c_checker
from onnx import IR_VERSION, ModelProto, NodeProto
class DeviceType:
    """Enumeration of the device kinds a backend can be asked to run on.

    Each member is an ``int`` tagged with the private ``_Type`` NewType so that
    static type checkers can distinguish device kinds from arbitrary integers.
    """

    # NewType is purely a typing aid: at runtime _Type(x) just returns x.
    _Type = NewType("_Type", int)

    CPU: _Type = _Type(0)
    CUDA: _Type = _Type(1)
class Device:
    """A device specification parsed from a string.

    Syntax: ``device_type[:device_id]``, for example ``'CPU'``, ``'CUDA'`` or
    ``'CUDA:1'``. ``device_type`` is looked up as an attribute of
    ``DeviceType``; ``device_id`` is 0 when it is omitted.

    Raises:
        AttributeError: if ``device_type`` is not an attribute of ``DeviceType``.
        ValueError: if the id part is present but is not an integer literal.
    """

    def __init__(self, device: str) -> None:
        # split() always yields at least one element, so this unpack is safe.
        kind, *suffix = device.split(":")
        # Resolve the type before touching the id, so an unknown type name is
        # reported even when the id is also malformed.
        self.type = getattr(DeviceType, kind)
        self.device_id = 0
        if suffix:
            # Only the first field after the type is used; later fields are ignored.
            self.device_id = int(suffix[0])
def namedtupledict(
    typename: str, field_names: Sequence[str], *args: Any, **kwargs: Any
) -> type[tuple[Any, ...]]:
    """Create a namedtuple class whose items can also be read by field name.

    Instances behave like regular namedtuples (integer indexing, slicing,
    attribute access). They additionally accept ``obj["name"]`` using the
    *original* field names, including names that are not valid Python
    identifiers (e.g. ``"0"``) and that ``namedtuple`` therefore renamed.

    Args:
        typename: Name of the generated class.
        field_names: Field names, either as a sequence of strings or as a
            single string separated by whitespace and/or commas (the same forms
            ``collections.namedtuple`` accepts).
        *args: Forwarded to ``collections.namedtuple``.
        **kwargs: Forwarded to ``collections.namedtuple``; ``rename`` defaults
            to ``True``.

    Returns:
        The generated tuple subclass.

    Raises:
        KeyError: on item access, if a string key is not one of the field names.
    """
    # Normalize the field names the same way namedtuple does. Without this, a
    # "a b" / "a,b" string would be enumerated character by character, and a
    # one-shot iterator would be exhausted before namedtuple saw it.
    if isinstance(field_names, str):
        field_names = field_names.replace(",", " ").split()
    field_names = list(field_names)
    field_names_map = {name: index for index, name in enumerate(field_names)}
    # Some output names are invalid python identifiers, e.g. "0"
    kwargs.setdefault("rename", True)
    data = namedtuple(typename, field_names, *args, **kwargs)  # type: ignore # noqa: PYI024

    def getitem(self: Any, key: Any) -> Any:
        if isinstance(key, str):
            key = field_names_map[key]
        # Bind to the generated class explicitly. With super(type(self), self),
        # a subclass instance would resolve straight back to this function and
        # recurse forever.
        return super(data, self).__getitem__(key)  # type: ignore

    data.__getitem__ = getitem  # type: ignore[assignment]
    return data
class BackendRep:
    """Handle produced by a Backend after it has prepared a model.

    A prepared model can be executed many times: pass inputs to :meth:`run`
    to obtain the corresponding results.
    """

    def run(self, inputs: Any, **kwargs: Any) -> tuple[Any, ...]:  # noqa: ARG002
        """Execute the prepared model on ``inputs`` (abstract placeholder).

        The base implementation ignores its arguments and returns a
        one-element tuple containing ``None``.
        """
        placeholder: tuple[Any, ...] = (None,)
        return placeholder
class Backend:
    """Backend is the entity that takes an ONNX model with inputs,
    performs a computation, and then returns the output.

    For one-off execution, users can use run_node and run_model to obtain
    results quickly.

    For repeated execution, users should use prepare, in which the Backend
    does all of the preparation work for executing the model repeatedly
    (e.g., loading initializers), and returns a BackendRep handle.
    """

    @classmethod
    def is_compatible(
        cls,
        model: ModelProto,  # noqa: ARG003
        device: str = "CPU",  # noqa: ARG003
        **kwargs: Any,  # noqa: ARG003
    ) -> bool:
        """Return whether ``model`` is compatible with this backend.

        The base implementation accepts every model and always returns
        ``True``; subclasses may override it to reject unsupported models.
        """
        return True

    @classmethod
    def prepare(
        cls,
        model: ModelProto,
        device: str = "CPU",  # noqa: ARG003
        **kwargs: Any,  # noqa: ARG003
    ) -> BackendRep | None:
        """Prepare ``model`` for repeated execution and return a handle.

        The base implementation only validates the model with
        ``onnx.checker.check_model`` and returns ``None``. Because
        :meth:`run_model` requires a non-``None`` result, concrete backends
        need to override this to return a ``BackendRep``.
        """
        # TODO Remove Optional from return type
        onnx.checker.check_model(model)
        return None

    @classmethod
    def run_model(
        cls, model: ModelProto, inputs: Any, device: str = "CPU", **kwargs: Any
    ) -> tuple[Any, ...]:
        """Prepare ``model`` and run it once on ``inputs``.

        ``device`` and ``kwargs`` are forwarded to :meth:`prepare` only; the
        returned ``BackendRep`` is then invoked as ``run(inputs)``.

        Raises:
            AssertionError: (when assertions are enabled) if :meth:`prepare`
                returned ``None``, as the base implementation does.
        """
        backend = cls.prepare(model, device, **kwargs)
        assert backend is not None
        return backend.run(inputs)

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,  # noqa: ARG003
        device: str = "CPU",  # noqa: ARG003
        outputs_info: (  # noqa: ARG003
            Sequence[tuple[numpy.dtype, tuple[int, ...]]] | None
        ) = None,
        # Each keyword *value* is arbitrary (e.g. opset_version=13), so the
        # per-value annotation is Any, not dict[str, Any].
        **kwargs: Any,
    ) -> tuple[Any, ...] | None:
        """Run a single operator and return its results.

        Args:
            node: The node proto.
            inputs: Inputs to the node.
            device: The device to run on.
            outputs_info: a list of tuples, which contains the element type and
                shape of each output. First element of the tuple is the dtype, and
                the second element is the shape. More use case can be found in
                https://github.com/onnx/onnx/blob/main/onnx/backend/test/runner/__init__.py
            kwargs: Other keyword arguments. If ``opset_version`` is given, the
                node is checked with a checker context whose opset import for
                the ``""`` domain is set to that version; otherwise the
                checker's default context is used.

        Returns:
            The base implementation only validates the node and returns ``None``.
        """
        # TODO Remove Optional from return type
        if "opset_version" in kwargs:
            special_context = c_checker.CheckerContext()
            special_context.ir_version = IR_VERSION
            special_context.opset_imports = {"": kwargs["opset_version"]}
            onnx.checker.check_node(node, special_context)
        else:
            onnx.checker.check_node(node)
        return None

    @classmethod
    def supports_device(cls, device: str) -> bool:  # noqa: ARG003
        """Checks whether the backend is compiled with particular device support.

        In particular it's used in the testing suite. The base implementation
        always returns ``True``.
        """
        return True