# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
# Public API of this module: the names exported by a star-import.
__all__ = [
"OpSignature",
"Parameter",
"AttributeParameter",
"TypeConstraintParam",
]
import dataclasses
import functools
from collections.abc import Iterator, Mapping, Sequence
from typing import Any
import onnx # noqa: TID251
from onnx_ir import _core, _enums, _protocols, serde
# A special value to indicate that the default value is not specified
class _Empty:
def __repr__(self) -> str:
return "_EMPTY_DEFAULT"
_EMPTY_DEFAULT = _Empty()
@functools.cache
def _all_value_types():
    """Return every tensor, sequence-of-tensor and optional-of-tensor type.

    One type of each of the three kinds is produced per ``DataType`` member.
    The result is an immutable frozenset and is computed only once thanks to
    ``functools.cache``.
    """
    tensor_types = {_core.TensorType(dtype) for dtype in _enums.DataType}
    sequence_types = {
        _core.SequenceType(_core.TensorType(dtype)) for dtype in _enums.DataType
    }
    optional_types = {
        _core.OptionalType(_core.TensorType(dtype)) for dtype in _enums.DataType
    }
    return frozenset(tensor_types | sequence_types | optional_types)
@dataclasses.dataclass(frozen=True)
class TypeConstraintParam:
    """Type constraint for a parameter.

    Attributes:
        name: Name of the parameter. E.g. "TFloat"
        allowed_types: Allowed types for the parameter. Any iterable may be
            passed to the constructor; it is stored as a frozenset.
        description: Human-readable description of the type constraint.
    """

    name: str
    allowed_types: frozenset[_protocols.TypeProtocol]
    description: str = ""

    def __post_init__(self):
        """Normalize ``allowed_types`` to a frozenset and reject empty constraints.

        Raises:
            ValueError: If ``allowed_types`` contains no type.
        """
        allowed_types = self.allowed_types
        if not isinstance(allowed_types, frozenset):
            # Normalize before validating: one-shot iterables such as
            # generators are always truthy, so checking emptiness on the raw
            # argument would let an empty one through. ``or ()`` keeps falsy
            # inputs (e.g. None, []) on the ValueError path below.
            allowed_types = frozenset(allowed_types or ())
            # The dataclass is frozen, so the normal attribute assignment is
            # blocked; go through object.__setattr__ instead.
            object.__setattr__(self, "allowed_types", allowed_types)
        if not allowed_types:
            raise ValueError(
                f"Type constraint '{self.name}' must have at least one allowed type."
            )

    def __str__(self) -> str:
        """Render as ``name=type1 | type2 | ...`` with the types sorted by text.

        Sorting makes the output deterministic; iterating the frozenset
        directly would yield an arbitrary order.
        """
        allowed_types_str = " | ".join(sorted(str(t) for t in self.allowed_types))
        return f"{self.name}={allowed_types_str}"

    @classmethod
    def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
        """Create a constraint that allows a tensor of every ``DataType`` member."""
        return cls(
            name, frozenset(_core.TensorType(dtype) for dtype in _enums.DataType), description
        )

    @classmethod
    def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
        """Create a constraint that allows every type from ``_all_value_types()``."""
        return cls(name, _all_value_types(), description)  # type: ignore[arg-type]
@dataclasses.dataclass(frozen=True)
class Parameter:
    """A formal parameter of an operator.

    Attributes:
        name: Name of the parameter.
        type_constraint: Constraint describing the types the parameter accepts.
        required: Whether the parameter is required.
        variadic: Whether the parameter is variadic.
        homogeneous: Homogeneity flag of the parameter; defaults to True.
        min_arity: Minimum arity of the parameter; defaults to 1.
        default: Default value, or the ``_EMPTY_DEFAULT`` sentinel when no
            default is specified.
    """

    name: str
    type_constraint: TypeConstraintParam
    required: bool
    variadic: bool
    homogeneous: bool = True
    min_arity: int = 1
    # TODO: Add differentiation_category
    default: Any = _EMPTY_DEFAULT

    def __str__(self) -> str:
        """Render as ``name: Constraint`` plus `` = default`` when one is set."""
        rendered = f"{self.name}: {self.type_constraint.name}"
        if not self.has_default():
            return rendered
        return f"{rendered} = {self.default}"

    def has_default(self) -> bool:
        """Return True when a default other than the sentinel was provided."""
        no_default = self.default is _EMPTY_DEFAULT
        return not no_default

    def is_param(self) -> bool:
        """This parameter is an ONNX input or output parameter, as opposed to an ONNX attribute parameter."""
        return True

    def is_attribute(self) -> bool:
        """Always False: this is not an ONNX attribute parameter."""
        return False
@dataclasses.dataclass(frozen=True)
class AttributeParameter:
    """A parameter in the function signature that represents an ONNX attribute.

    Attributes:
        name: Name of the attribute.
        type: Attribute type; its ``name`` is used when rendering.
        required: Whether the attribute is required.
        default: Default attribute value, or None when there is no default.
    """

    name: str
    type: _enums.AttributeType
    required: bool
    default: _core.Attr | None = None

    def __str__(self) -> str:
        """Render as ``name: TYPE`` plus `` = default`` when one is set."""
        rendered = f"{self.name}: {self.type.name}"
        if not self.has_default():
            return rendered
        return f"{rendered} = {self.default}"

    def has_default(self) -> bool:
        """Return True when ``default`` is anything other than None."""
        missing = self.default is None
        return not missing

    def is_param(self) -> bool:
        """Always False: this is not an ONNX input or output parameter."""
        return False

    def is_attribute(self) -> bool:
        """This parameter is an ONNX attribute parameter, as opposed to an ONNX input or output parameter."""
        return True
def _get_type_from_str(
    type_str: str,
) -> _core.TensorType | _core.SequenceType | _core.OptionalType:
    """Convert a type_str from ONNX OpSchema to _protocols.TypeProtocol.

    A type str has the form of "tensor(float)" or composite type like
    "seq(tensor(float))". A string without an explicit "tensor(...)" wrapper,
    such as a plain "int64", is treated as a tensor of that element type
    instead of yielding a tensor with an UNDEFINED dtype.

    Args:
        type_str: The type string to parse.

    Returns:
        The type built from the innermost element type outwards.

    Raises:
        ValueError: If a wrapper other than "tensor", "seq" or "optional" is
            encountered.
        KeyError: If the element type name is not a member of
            ``_enums.DataType`` (raised by the enum lookup).
    """
    # Drop the closing parentheses and split on "(": the wrappers come first
    # (outermost to innermost) and the element type name is the last piece.
    *wrappers, element_name = type_str.rstrip(")").split("(")
    # Convert the dtype to _enums.DataType
    dtype = _enums.DataType[element_name.upper()]
    # Seed the innermost type with the parsed dtype so that strings lacking a
    # "tensor" wrapper do not leak an UNDEFINED placeholder to the caller.
    type_: _protocols.TypeProtocol = _core.TensorType(dtype)
    # Apply the wrappers from the innermost to the outermost.
    for type_part in reversed(wrappers):
        if type_part == "tensor":
            type_ = _core.TensorType(dtype)
        elif type_part == "seq":
            type_ = _core.SequenceType(type_)
        elif type_part == "optional":
            type_ = _core.OptionalType(type_)
        else:
            raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'")
    return type_  # type: ignore[return-value]
def _convert_formal_parameter(
    param: onnx.defs.OpSchema.FormalParameter,
    type_constraints: Mapping[str, TypeConstraintParam],
) -> Parameter:
    """Convert a formal parameter from ONNX OpSchema to Parameter.

    Args:
        param: The OpSchema formal parameter to convert.
        type_constraints: Known type constraints keyed by type parameter name.

    Returns:
        A Parameter that is required unless the option is Optional, and
        variadic when the option is Variadic.
    """
    option_kind = onnx.defs.OpSchema.FormalParameterOption
    if param.type_str not in type_constraints:
        # The type string is not a declared type parameter, so parse it as a
        # concrete type and wrap it in a single-type constraint named after
        # the parameter itself.
        concrete_type = _get_type_from_str(param.type_str)
        type_constraint = TypeConstraintParam(
            name=param.name,
            allowed_types=frozenset((concrete_type,)),
        )
    else:
        type_constraint = type_constraints[param.type_str]
    is_required = param.option != option_kind.Optional
    is_variadic = param.option == option_kind.Variadic
    return Parameter(
        name=param.name,
        type_constraint=type_constraint,
        required=is_required,
        variadic=is_variadic,
        homogeneous=param.is_homogeneous,
        min_arity=param.min_arity,
    )
@dataclasses.dataclass
class OpSignature:
    """Schema for an operator.

    Attributes:
        domain: Domain of the operator. E.g. "".
        name: Name of the operator. E.g. "Add".
        overload: Overload name of the operator.
        params: Input parameters. When the op is an ONNX function definition,
            the order is according to the function signature. This means we can
            interleave ONNX inputs and ONNX attributes in the list.
        outputs: Output parameters.
        params_map: Lookup from parameter name to the matching entry of
            ``params``. Built in ``__post_init__``; not a constructor argument.
        since_version: The version of the operator set. E.g. 1.
    """

    domain: str
    name: str
    overload: str
    params: Sequence[Parameter | AttributeParameter]
    outputs: Sequence[Parameter]
    params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field(
        init=False, repr=False
    )
    since_version: int = 1

    def __post_init__(self):
        """Build ``params_map`` from ``params``.

        Raises:
            ValueError: If two entries of ``params`` share the same name.
        """
        by_name: dict[str, Parameter | AttributeParameter] = {}
        for entry in self.params:
            key = entry.name
            if key in by_name:
                raise ValueError(
                    f"Duplicate parameter name {key!r} in OpSignature "
                    f"{self.domain!r}::{self.name!r}"
                )
            by_name[key] = entry
        self.params_map = by_name

    def get(
        self,
        name: str,
        default: Parameter | AttributeParameter | None = None,
    ) -> Parameter | AttributeParameter | None:
        """Return the parameter called ``name``, or ``default`` if there is none."""
        lookup = self.params_map
        return lookup.get(name, default)

    def __contains__(self, name: str) -> bool:
        """Return True when a parameter called ``name`` exists."""
        return name in self.params_map

    def __iter__(self) -> Iterator[Parameter | AttributeParameter]:
        """Iterate over ``params`` in their declared order."""
        return iter(self.params)

    def __str__(self) -> str:
        """Render as ``domain::name::overload(params) -> (outputs) where constraints``.

        An empty domain is shown as ``''`` and the overload part is omitted
        when the overload name is empty. Each type constraint is listed once
        per constraint name.
        """
        domain_text = self.domain if self.domain else "''"
        overload_suffix = "" if not self.overload else f"::{self.overload}"
        param_text = ", ".join(map(str, self.params))
        output_text = ", ".join(str(out.type_constraint.name) for out in self.outputs)
        # Attribute parameters carry no type constraint, so only inputs and
        # outputs contribute. Keying by name de-duplicates shared constraints;
        # the first occurrence fixes the position, the last one the value.
        constrained = [entry for entry in self.params if isinstance(entry, Parameter)]
        constrained.extend(self.outputs)
        constraints_by_name = {
            entry.type_constraint.name: entry.type_constraint for entry in constrained
        }
        where_text = ", ".join(map(str, constraints_by_name.values()))
        return f"{domain_text}::{self.name}{overload_suffix}({param_text}) -> ({output_text}) where {where_text}"

    @property
    def inputs(self) -> Sequence[Parameter]:
        """Returns the input parameters."""
        return [entry for entry in self.params if isinstance(entry, Parameter)]

    @property
    def attributes(self) -> Sequence[AttributeParameter]:
        """Returns the attribute parameters."""
        return [entry for entry in self.params if isinstance(entry, AttributeParameter)]

    @classmethod
    def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature:
        """Produce an OpSignature from an ONNX OpSchema.

        The resulting ``params`` list holds the schema inputs first, followed
        by the schema attributes; the overload name is left empty.
        """
        # Type constraints declared by the schema, keyed by type parameter name.
        type_constraints = {}
        for constraint in op_schema.type_constraints:
            allowed = frozenset(
                _get_type_from_str(type_str) for type_str in constraint.allowed_type_strs
            )
            type_constraints[constraint.type_param_str] = TypeConstraintParam(
                name=constraint.type_param_str,
                allowed_types=allowed,
                description=constraint.description,
            )

        params = []
        for formal_input in op_schema.inputs:
            params.append(_convert_formal_parameter(formal_input, type_constraints))

        for attr_schema in op_schema.attributes.values():
            # NOTE(review): a missing default is assumed to be reported as
            # None here - confirm against the onnx OpSchema.Attribute API.
            default_attr = None
            if attr_schema.default_value is not None:
                default_attr = serde.deserialize_attribute(attr_schema.default_value)
            if default_attr is not None:
                # Set the name of the default attribute because it may have a different name from the parameter
                default_attr.name = attr_schema.name
            attribute_param = AttributeParameter(
                name=attr_schema.name,
                type=_enums.AttributeType(attr_schema.type),  # type: ignore[arg-type]
                required=attr_schema.required,
                default=default_attr,  # type: ignore[arg-type]
            )
            params.append(attribute_param)

        outputs = []
        for formal_output in op_schema.outputs:
            outputs.append(_convert_formal_parameter(formal_output, type_constraints))

        return cls(
            domain=op_schema.domain,
            name=op_schema.name,
            overload="",
            params=params,
            outputs=outputs,
            since_version=op_schema.since_version,
        )