Source code for onnx.tools.update_model_dims

# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from typing import Any

import onnx.checker
from onnx import ModelProto, ValueInfoProto


def update_inputs_outputs_dims(
    model: ModelProto,
    input_dims: dict[str, list[Any]],
    output_dims: dict[str, list[Any]],
) -> ModelProto:
    """Update the dimension sizes of the model's inputs and outputs.

    Each graph input and output is looked up by name in ``input_dims`` /
    ``output_dims`` and every listed axis is rewritten:

    * a non-negative ``int`` sets ``dim_value`` (and must not contradict a
      ``dim_value`` that is already present),
    * a ``str`` sets ``dim_param`` to that symbolic name,
    * a negative ``int`` sets ``dim_param`` to the generated name
      ``"<tensor name>_<axis>"``, which must be unique.

    Example. If we have the following shape for inputs and outputs:

    * shape(input_1) = ('b', 3, 'w', 'h')
    * shape(input_2) = ('b', 4)
    * shape(output) = ('b', 'd', 5)

    The parameters can be provided as:

    ::

        input_dims = {
            "input_1": ['b', 3, 'w', 'h'],
            "input_2": ['b', 4],
        }
        output_dims = {
            "output": ['b', -1, 5]
        }

    Putting it together:

    ::

        model = onnx.load('model.onnx')
        updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
        onnx.save(updated_model, 'model.onnx')

    Args:
        model: Model whose graph inputs/outputs are updated. It is modified
            in place.
        input_dims: Maps every graph input name to its list of dimensions.
        output_dims: Maps every graph output name to its list of dimensions.

    Returns:
        The same ``model`` object, after ``onnx.checker.check_model`` has
        been run on it.

    Raises:
        ValueError: If a non-negative value contradicts an existing
            ``dim_value``, if a generated ``dim_param`` name is already in
            use (in the model or among the symbolic names passed in), or if
            a dimension is neither ``int`` nor ``str``.
        KeyError: If a graph input/output name has no entry in
            ``input_dims`` / ``output_dims``.
    """
    # Symbolic names that auto-generated dim_params must not collide with.
    dim_param_set: set[str] = set()

    def init_dim_param_set(
        dim_param_set: set[str], value_infos: list[ValueInfoProto]
    ) -> None:
        # Collect every dim_param already present on the given value infos.
        for info in value_infos:
            shape = info.type.tensor_type.shape
            for dim in shape.dim:
                if dim.HasField("dim_param"):
                    dim_param_set.add(dim.dim_param)  # type: ignore

    init_dim_param_set(dim_param_set, model.graph.input)  # type: ignore
    init_dim_param_set(dim_param_set, model.graph.output)  # type: ignore
    init_dim_param_set(dim_param_set, model.graph.value_info)  # type: ignore

    def reserve_requested_dim_params(
        value_infos: list[ValueInfoProto], requested_dims: dict[str, list[Any]]
    ) -> None:
        # Symbolic names supplied by the caller are about to be written into
        # the model, so a generated name must not silently alias one of them.
        # Reserving them before any update keeps the check independent of the
        # order in which tensors are processed. ``.get`` is used so that a
        # missing name still surfaces as a KeyError from the update loops
        # below, exactly where it did before.
        for info in value_infos:
            for dim in requested_dims.get(info.name, ()):
                if isinstance(dim, str):
                    dim_param_set.add(dim)

    reserve_requested_dim_params(model.graph.input, input_dims)  # type: ignore
    reserve_requested_dim_params(model.graph.output, output_dims)  # type: ignore

    def update_dim(tensor: ValueInfoProto, dim: Any, j: int, name: str) -> None:
        # Apply one requested dimension to axis ``j`` of ``tensor``.
        dim_proto = tensor.type.tensor_type.shape.dim[j]
        if isinstance(dim, int):
            if dim >= 0:
                if dim_proto.HasField("dim_value") and dim_proto.dim_value != dim:
                    raise ValueError(
                        f"Unable to set dimension value to {dim} for axis {j} of {name}. Contradicts existing dimension value {dim_proto.dim_value}."
                    )
                dim_proto.dim_value = dim
            else:
                # Generated names are intentionally not added to the reserved
                # set: "<name>_<axis>" can only repeat for the same tensor
                # name and axis (e.g. a tensor listed as both graph input and
                # graph output), where sharing the symbol is accepted.
                generated_dim_param = name + "_" + str(j)
                if generated_dim_param in dim_param_set:
                    raise ValueError(
                        f"Unable to generate unique dim_param for axis {j} of {name}. Please manually provide a dim_param value."
                    )
                dim_proto.dim_param = generated_dim_param
        elif isinstance(dim, str):
            dim_proto.dim_param = dim
        else:
            raise ValueError(  # noqa: TRY004
                f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}"
            )

    for input_ in model.graph.input:
        input_name = input_.name
        input_dim_arr = input_dims[input_name]
        for j, dim in enumerate(input_dim_arr):
            update_dim(input_, dim, j, input_name)

    for output in model.graph.output:
        output_name = output.name
        output_dim_arr = output_dims[output_name]
        for j, dim in enumerate(output_dim_arr):
            update_dim(output, dim, j, output_name)

    onnx.checker.check_model(model)
    return model