# Source code for skl2onnx.proto
# SPDX-License-Identifier: Apache-2.0
# Rather than using ONNX protobuf definition throughout our codebase,
# we import ONNX protobuf definition here so that we can conduct quick
# fixes by overwriting ONNX functions without changing any lines
# elsewhere.
from onnx import onnx_pb as onnx_proto # noqa
from onnx import defs # noqa
# Overwrite the make_tensor defined in onnx.helper because of a bug
# (string tensors get assigned twice).
from onnx import mapping
from onnx.onnx_pb import TensorProto, ValueInfoProto # noqa
try:
from onnx.onnx_pb import SparseTensorProto # noqa
except ImportError:
# onnx is too old.
pass
from onnx.helper import split_complex_to_pairs
def make_tensor_fixed(name, data_type, dims, vals, raw=False):
    """
    Make a TensorProto with specified arguments.

    If *raw* is False, this function chooses the proto field used to
    store the values based on *data_type*. If *raw* is True, the
    ``raw_data`` proto field stores the values, and *vals* should be of
    type bytes in this case.

    :param name: name assigned to the tensor
    :param data_type: element type, one of the ``TensorProto`` data
        type constants
    :param dims: dimensions appended to ``tensor.dims``
    :param vals: values to store (bytes when *raw* is True)
    :param raw: if True, *vals* is written to ``raw_data`` unchanged
    :return: the populated ``TensorProto``
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name
    if raw:
        # Raw payloads are stored as given. Complex splitting must not
        # be applied here: it would turn the bytes into a list, which
        # cannot be assigned to ``raw_data``.
        tensor.raw_data = vals
    else:
        if data_type in (TensorProto.COMPLEX64, TensorProto.COMPLEX128):
            # Complex values are passed through split_complex_to_pairs
            # before being written to the typed field.
            vals = split_complex_to_pairs(vals)
        storage_type = mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[data_type]
        field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[storage_type]
        getattr(tensor, field).extend(vals)
    tensor.dims.extend(dims)
    return tensor
def get_opset_number_from_onnx():
    """
    Gives the most recent opset version supported
    by the installed *onnx* package.
    """
    latest_opset = defs.onnx_opset_version()
    return latest_opset
def get_latest_tested_opset_version():
    """
    This module relies on *onnxruntime* to test every
    converter. The function returns the most recent
    target opset tested with *onnxruntime*, or the opset
    version specified by the *onnx* package if this one is lower
    (as returned by `onnx.defs.onnx_opset_version()`).
    """
    # Imported inside the function rather than at module level;
    # presumably to avoid a circular import with the parent package
    # -- TODO confirm.
    from .. import __max_supported_opset__

    return min(__max_supported_opset__, get_opset_number_from_onnx())