Source code for onnxconverter_common.utils
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################
import numbers
import numpy as np
import warnings
import packaging.version as pv
def sparkml_installed():
"""
Checks that *spark* is available.
"""
try:
import pyspark # noqa F401
return True
except ImportError:
return False
def sklearn_installed():
"""
Checks that *scikit-learn* is available.
"""
try:
import sklearn # noqa F401
return True
except ImportError:
return False
def skl2onnx_installed():
"""
Checks that *skl2onnx* converter is available.
"""
try:
import skl2onnx # noqa F401
return True
except ImportError:
return False
def coreml_installed():
"""
Checks that *coremltools* is available.
"""
try:
import coremltools # noqa F401
return True
except ImportError:
return False
def keras2onnx_installed():
"""
Checks that *keras2onnx* is available.
"""
try:
import keras2onnx # noqa F401
return True
except ImportError:
return False
def torch_installed():
"""
Checks that *pytorch* is available.
"""
try:
import torch # noqa F401
return True
except ImportError:
return False
def caffe2_installed():
"""
Checks that *caffe* is available.
"""
try:
import caffe2 # noqa F401
return True
except ImportError:
return False
def libsvm_installed():
"""
Checks that *libsvm* is available.
"""
try:
import svm # noqa F401
import svmutil # noqa F401
return True
except ImportError:
return False
def lightgbm_installed():
"""
Checks that *lightgbm* is available.
"""
try:
import lightgbm # noqa F401
return True
except ImportError:
return False
def xgboost_installed():
"""
Checks that *xgboost* is available.
"""
try:
import xgboost # noqa F401
except ImportError:
return False
from xgboost.core import _LIB
try:
_LIB.XGBoosterDumpModelEx
except AttributeError:
# The version is not recent enough even though it is version 0.6.
# You need to install xgboost from github and not from pypi.
return False
from xgboost import __version__
vers = pv.Version(__version__)
allowed = pv.Version('0.7')
if vers < allowed:
warnings.warn('The converter works for xgboost >= 0.7. Earlier versions might not.')
return True
def h2o_installed():
"""
Checks that *h2o* is available.
"""
try:
import h2o # noqa F401
except ImportError:
return False
return True
def hummingbird_installed():
"""
Checks that *Hummingbird* is available.
"""
try:
import hummingbird.ml # noqa: F401
return True
except ImportError:
return False
def get_producer():
"""
Internal helper function to return the producer
"""
from . import __producer__
return __producer__
def get_producer_version():
"""
Internal helper function to return the producer version
"""
from . import __producer_version__
return __producer_version__
def get_domain():
"""
Internal helper function to return the model domain
"""
from . import __domain__
return __domain__
def get_model_version():
"""
Internal helper function to return the model version
"""
from . import __model_version__
return __model_version__
def is_numeric_type(item):
numeric_types = (int, float, complex)
types = numeric_types
if isinstance(item, list):
return all(isinstance(i, types) for i in item)
if isinstance(item, np.ndarray):
return np.issubdtype(item.dtype, np.number)
return isinstance(item, types)
def is_string_type(item):
if isinstance(item, list):
return all(isinstance(i, str) for i in item)
if isinstance(item, np.ndarray):
return np.issubdtype(item.dtype, np.str_)
return isinstance(item, str)
def cast_list(type, items):
return [type(item) for item in items]
def convert_to_python_value(var):
if isinstance(var, numbers.Integral):
return int(var)
elif isinstance(var, numbers.Real):
return float(var)
elif isinstance(var, str):
return str(var)
else:
raise TypeError('Unable to convert {0} to python type'.format(type(var)))
def convert_to_python_default_value(var):
if isinstance(var, numbers.Integral):
return int()
elif isinstance(var, numbers.Real):
return float()
elif isinstance(var, str):
return str()
else:
raise TypeError('Unable to find default python value for type {0}'.format(type(var)))
def convert_to_list(var):
if isinstance(var, numbers.Real) or isinstance(var, str):
return [convert_to_python_value(var)]
elif isinstance(var, np.ndarray) and len(var.shape) == 1:
return [convert_to_python_value(v) for v in var]
elif isinstance(var, list):
flattened = []
if all(isinstance(ele, np.ndarray) and len(ele.shape) == 1 for ele in var):
max_classes = max([ele.shape[0] for ele in var])
flattened_one = []
for ele in var:
for i in range(max_classes):
if i < ele.shape[0]:
flattened_one.append(convert_to_python_value(ele[i]))
else:
flattened_one.append(convert_to_python_default_value(ele[0]))
flattened += flattened_one
return flattened
elif all(isinstance(v, numbers.Real) or isinstance(v, str) for v in var):
return [convert_to_python_value(v) for v in var]
else:
raise TypeError('Unable to flatten variable')
else:
raise TypeError('Unable to flatten variable')
[docs]
def check_input_and_output_numbers(operator, input_count_range=None, output_count_range=None):
'''
Check if the number of input(s)/output(s) is correct
:param operator: A Operator object
:param input_count_range: A list of two integers or an integer.
If it's a list the first/second element is the
minimal/maximal number of inputs. If it's an integer,
it is equivalent to specify that number twice in a list. For
infinite ranges like 5 to infinity, you need to use [5, None].
:param output_count_range: A list of two integers or an integer.
See input_count_range for its format.
'''
if isinstance(input_count_range, list):
min_input_count = input_count_range[0]
max_input_count = input_count_range[1]
elif isinstance(input_count_range, int) or input_count_range is None:
min_input_count = input_count_range
max_input_count = input_count_range
else:
raise RuntimeError('input_count_range must be a list or an integer')
if isinstance(output_count_range, list):
min_output_count = output_count_range[0]
max_output_count = output_count_range[1]
elif isinstance(output_count_range, int) or output_count_range is None:
min_output_count = output_count_range
max_output_count = output_count_range
else:
raise RuntimeError('output_count_range must be a list or an integer')
if min_input_count is not None and len(operator.inputs) < min_input_count:
raise RuntimeError(
'For operator %s (type: %s), at least %s input(s) is(are) required but we got %s input(s) which are %s'
% (operator.full_name, operator.type, min_input_count, len(operator.inputs), operator.input_full_names))
if max_input_count is not None and len(operator.inputs) > max_input_count:
raise RuntimeError(
'For operator %s (type: %s), at most %s input(s) is(are) supported but we got %s input(s) which are %s'
% (operator.full_name, operator.type, max_input_count, len(operator.inputs), operator.input_full_names))
if min_output_count is not None and len(operator.outputs) < min_output_count:
raise RuntimeError(
'For operator %s (type: %s), at least %s output(s) is(are) produced but we got %s output(s) which are %s'
% (operator.full_name, operator.type, min_output_count, len(operator.outputs), operator.output_full_names))
if max_output_count is not None and len(operator.outputs) > max_output_count:
raise RuntimeError(
'For operator %s (type: %s), at most %s outputs(s) is(are) supported but we got %s output(s) which are %s'
% (operator.full_name, operator.type, max_output_count, len(operator.outputs), operator.output_full_names))
[docs]
def check_input_and_output_types(operator, good_input_types=None, good_output_types=None):
'''
Check if the type(s) of input(s)/output(s) is(are) correct
:param operator: A Operator object
:param good_input_types: A list of allowed input types
(e.g., [FloatTensorType, Int64TensorType]) or None. None
means that we skip the check of the input types.
:param good_output_types: A list of allowed output types.
See good_input_types for its format.
'''
if good_input_types is not None:
for variable in operator.inputs:
if type(variable.type) not in good_input_types:
raise RuntimeError('Operator %s (type: %s) got an input %s with a wrong type %s. Only %s are allowed'
% (operator.full_name, operator.type, variable.full_name, type(variable.type),
good_input_types))
if good_output_types is not None:
for variable in operator.outputs:
if type(variable.type) not in good_output_types:
raise RuntimeError('Operator %s (type: %s) got an output %s with a wrong type %s. Only %s are allowed'
% (operator.full_name, operator.type, variable.full_name, type(variable.type),
good_output_types))