Source code for onnx_ir.passes.common.initializer_deduplication
# Copyright (c) ONNX Project Contributors# SPDX-License-Identifier: Apache-2.0"""Pass for removing duplicated initializer tensors from a graph."""from__future__importannotations__all__=["DeduplicateInitializersPass",]importonnx_irasir
class DeduplicateInitializersPass(ir.passes.InPlacePass):
    """Collapse initializers that hold identical tensors into a single one.

    Two initializers are considered duplicates when their dtype, shape, and
    serialized bytes all match. The first one encountered is kept as the
    canonical value; every use of a later duplicate is redirected to it and
    the duplicate is removed from the graph's initializers.

    Only the initializers of the main graph are visited. To deduplicate
    initializers that live in subgraphs, run
    :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
    first so they are lifted to the main graph before running this pass.

    Args:
        size_limit: Initializers whose constant value reports a ``size``
            greater than this limit are left untouched.
    """

    def __init__(self, size_limit: int = 1024):
        super().__init__()
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Deduplicate the main graph's initializers in place.

        Args:
            model: The model whose main graph is rewritten.

        Returns:
            A pass result wrapping the same model, with ``modified`` set to
            ``True`` when at least one duplicate initializer was removed.
        """
        main_graph = model.graph
        # Maps (dtype, shape, raw bytes) to the first initializer seen with
        # that exact content.
        canonical_by_key: dict[
            tuple[ir.DataType, tuple[int, ...], bytes], ir.Value
        ] = {}
        changed = False

        # TODO(justinchuby): Handle subgraphs as well. For now users can lift
        # subgraph initializers to the main graph before running this pass.
        # Iterate over a snapshot because duplicates are popped from
        # ``main_graph.initializers`` inside the loop.
        for candidate in tuple(main_graph.initializers.values()):
            tensor = candidate.const_value
            # Skip initializers without a constant value and those above the
            # configured size limit.
            if tensor is None or tensor.size > self.size_limit:
                continue

            key = (tensor.dtype, tuple(tensor.shape), tensor.tobytes())
            if key not in canonical_by_key:
                # First occurrence of this content: remember it as canonical.
                canonical_by_key[key] = candidate  # type: ignore[index]
                continue

            # Duplicate content: reroute its uses, then drop it from the graph.
            changed = True
            ir.convenience.replace_all_uses_with(candidate, canonical_by_key[key])  # type: ignore[index]
            assert candidate.name is not None
            main_graph.initializers.pop(candidate.name)

        return ir.passes.PassResult(model=model, modified=changed)