Source code for onnx_ir.passes.common.initializer_deduplication

# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Pass for removing duplicated initializer tensors from a graph."""

from __future__ import annotations

__all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]


import hashlib
import logging

import onnx_ir as ir

logger = logging.getLogger(__name__)


def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
    """Check if the initializer should be skipped for deduplication."""
    if initializer.is_graph_input() or initializer.is_graph_output():
        # Skip graph inputs and outputs
        logger.warning(
            "Skipped deduplication of initializer '%s' as it is a graph input or output",
            initializer.name,
        )
        return True

    const_val = initializer.const_value
    if const_val is None:
        # Skip if initializer has no constant value
        logger.warning(
            "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
            initializer.name,
        )
        return True

    if const_val.size > size_limit:
        # Skip if the initializer is larger than the size limit
        logger.debug(
            "Skipped initializer '%s' as it exceeds the size limit of %d elements",
            initializer.name,
            size_limit,
        )
        return True

    if const_val.dtype == ir.DataType.STRING:
        # Skip string initializers as they don't have a bytes representation
        logger.warning(
            "Skipped deduplication of string initializer '%s' (unsupported yet)",
            initializer.name,
        )
        return True
    return False


class DeduplicateInitializersPass(ir.passes.InPlacePass):
    """Remove duplicated initializer tensors from the main graph and all subgraphs.

    This pass detects initializers with identical shape, dtype, and content, and
    replaces all duplicate references with a canonical one.

    Initializers are deduplicated within each graph. To deduplicate initializers in
    the model globally (across graphs), use
    :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass` to lift
    the initializers to the main graph before running this pass.

    .. versionadded:: 0.1.3

    .. versionchanged:: 0.1.7
        This pass now deduplicates initializers in subgraphs as well.
    """

    def __init__(self, size_limit: int = 1024):
        super().__init__()
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        modified = False
        for graph in model.graphs():
            initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
            for initializer in tuple(graph.initializers.values()):
                if _should_skip_initializer(initializer, self.size_limit):
                    continue
                const_val = initializer.const_value
                assert const_val is not None
                key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
                if key in initializers:
                    modified = True
                    initializer_to_keep = initializers[key]  # type: ignore[index]
                    ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
                    assert initializer.name is not None
                    graph.initializers.pop(initializer.name)
                    logger.info(
                        "Replaced initializer '%s' with existing initializer '%s'",
                        initializer.name,
                        initializer_to_keep.name,
                    )
                else:
                    initializers[key] = initializer  # type: ignore[index]
        return ir.passes.PassResult(model=model, modified=modified)
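

# A minimal usage sketch (not part of this module): applying the pass to a loaded
# model. The path "model.onnx" is a placeholder; ir.load/ir.save are the onnx_ir
# model I/O helpers, and calling the pass returns an ir.passes.PassResult whose
# `modified` flag reports whether any duplicate initializer was removed.
#
#   import onnx_ir as ir
#   from onnx_ir.passes.common.initializer_deduplication import (
#       DeduplicateInitializersPass,
#   )
#
#   model = ir.load("model.onnx")
#   result = DeduplicateInitializersPass(size_limit=1024)(model)
#   if result.modified:
#       ir.save(result.model, "model_deduplicated.onnx")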


class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
    """Remove duplicated initializer tensors (using a hashed method) from the graph.

    This pass detects initializers with identical shape, dtype, and hashed content,
    and replaces all duplicate references with a canonical one.

    This pass should have a lower peak memory usage than
    :class:`DeduplicateInitializersPass` as it does not store the full tensor data in
    memory, but instead uses a hash of the tensor data.

    .. versionadded:: 0.1.7
    """

    def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
        super().__init__()
        # 4 GB default size limit for deduplication
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        modified = False
        for graph in model.graphs():
            initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
            for initializer in tuple(graph.initializers.values()):
                if _should_skip_initializer(initializer, self.size_limit):
                    continue
                const_val = initializer.const_value
                assert const_val is not None
                # Hash tensor data to avoid storing large amounts of data in memory
                hashed = hashlib.sha512()
                tensor_data = const_val.numpy()
                hashed.update(tensor_data)
                tensor_digest = hashed.hexdigest()
                tensor_dims = tuple(const_val.shape.numpy())
                key = (const_val.dtype, tensor_dims, tensor_digest)
                if key in initializers:
                    if initializers[key].const_value.tobytes() != const_val.tobytes():
                        logger.warning(
                            "Initializer deduplication failed: "
                            "hashes match but values differ with values %s and %s",
                            initializers[key],
                            initializer,
                        )
                        continue
                    modified = True
                    initializer_to_keep = initializers[key]  # type: ignore[index]
                    ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
                    assert initializer.name is not None
                    graph.initializers.pop(initializer.name)
                    logger.info(
                        "Replaced initializer '%s' with existing initializer '%s'",
                        initializer.name,
                        initializer_to_keep.name,
                    )
                else:
                    initializers[key] = initializer  # type: ignore[index]
        return ir.passes.PassResult(model=model, modified=modified)
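

# A minimal usage sketch (assumptions: "model.onnx" is a placeholder path and
# ir.load/ir.save are the onnx_ir model I/O helpers). For very large models this
# variant keeps peak memory lower by keying duplicates on a SHA-512 digest of the
# tensor data instead of the raw bytes.
#
#   import onnx_ir as ir
#   from onnx_ir.passes.common.initializer_deduplication import (
#       DeduplicateHashedInitializersPass,
#   )
#
#   model = ir.load("model.onnx")
#   result = DeduplicateHashedInitializersPass()(model)
#   print("modified:", result.modified)
#   ir.save(result.model, "model_deduplicated.onnx")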