# Source code for onnx_ir.passes.common.default_attributes
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Add default attributes to nodes that are missing optional attributes."""
from __future__ import annotations
__all__ = [
"AddDefaultAttributesPass",
]
import logging
import onnx # noqa: TID251
import onnx_ir as ir
logger = logging.getLogger(__name__)
def _has_valid_default(attr_def: onnx.defs.OpSchema.Attribute) -> bool:
    """Tell whether a schema attribute definition carries a usable default.

    Args:
        attr_def: Attribute definition taken from an ONNX operator schema.

    Returns:
        True when ``attr_def.default_value`` is truthy and its ``type`` is
        not ``AttributeProto.UNDEFINED``; False otherwise.
    """
    default = attr_def.default_value
    # Guard first: a falsy default can never be a usable one.
    if not default:
        return False
    # An UNDEFINED type marks a placeholder rather than a real default.
    return bool(default.type != onnx.AttributeProto.UNDEFINED)
# [docs]
class AddDefaultAttributesPass(ir.passes.InPlacePass):
    """Add default values for optional attributes that are not present in nodes.

    This pass iterates through all nodes in the model and for each node:

    1. Gets the ONNX schema for the operator
    2. For each optional attribute with a default value in the schema
    3. If the attribute is not present in the node, adds it with the default value

    Nodes of the main graph (including nested subgraphs) are resolved against
    the graph's opset imports. Nodes inside model-local functions are resolved
    against the function's own opset imports, falling back to the graph's
    imports for domains the function does not declare.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run the pass over the main graph and every function of ``model``.

        Args:
            model: The model whose nodes are updated in place.

        Returns:
            A ``PassResult`` wrapping the same model, with ``modified`` set to
            True when at least one attribute was added.
        """
        modified = False
        graph_opset_imports = model.graph.opset_imports

        # Process all nodes in the model graph and subgraphs
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            if _add_default_attributes_to_node(node, graph_opset_imports):
                modified = True

        # Process nodes in functions
        for function in model.functions.values():
            # A function declares its own opset imports, which may differ from
            # the main graph's. Its entries take precedence; the graph's
            # imports are kept only as a fallback for undeclared domains.
            function_opset_imports = {**graph_opset_imports, **function.opset_imports}
            for node in ir.traversal.RecursiveGraphIterator(function):
                if _add_default_attributes_to_node(node, function_opset_imports):
                    modified = True

        if modified:
            logger.info("AddDefaultAttributes pass modified the model")
        return ir.passes.PassResult(model, modified=modified)
def _add_default_attributes_to_node(node: ir.Node, opset_imports: dict[str, int]) -> bool:
    """Fill in schema defaults for optional attributes missing from ``node``.

    Args:
        node: The node to update in place.
        opset_imports: Mapping from operator domain to opset version, used
            when the node does not carry its own version.

    Returns:
        True if at least one attribute was added to the node. False if nothing
        was added, including when the opset version or the operator schema
        could not be determined.
    """
    # A version set on the node itself wins over the opset import table.
    opset_version = node.version
    if opset_version is None:
        if node.domain not in opset_imports:
            logger.warning(
                "OpSet version for domain '%s' not found. Skipping node %s",
                node.domain,
                node,
            )
            return False
        opset_version = opset_imports[node.domain]

    try:
        schema = onnx.defs.get_schema(node.op_type, opset_version, domain=node.domain)
    except onnx.defs.SchemaError:
        logger.debug(
            "Schema not found for %s, skipping default attribute addition",
            node,
        )
        return False

    added_any = False
    for name, definition in schema.attributes.items():
        # Only optional attributes that are absent from the node and have a
        # usable default in the schema are candidates.
        should_fill = (
            not definition.required
            and name not in node.attributes
            and _has_valid_default(definition)
        )
        if not should_fill:
            continue

        # Convert the schema's AttributeProto default into an IR attribute.
        node.attributes[name] = ir.serde.deserialize_attribute(definition.default_value)
        logger.debug("Added default attribute '%s' to node %s", name, node)
        added_any = True

    return added_any