# SPDX-License-Identifier: Apache-2.0
"""
Placeholder for the ONNX operator classes that wrap scikit-learn estimators.
"""
import sys
import textwrap
from sklearn.pipeline import Pipeline, FeatureUnion
try:
from sklearn.compose import ColumnTransformer
except ImportError:
# ColumnTransformer was introduced in 0.20.
ColumnTransformer = None
from .onnx_subgraph_operator_mixin import OnnxSubGraphOperatorMixin
def ClassFactorySklearn(skl_obj, class_name, doc, conv, shape_calc, alias):
    """
    Creates a class named *class_name* that derives from
    :class:`OnnxSubGraphOperatorMixin` first and *skl_obj* second.

    :param skl_obj: class to wrap; its ``__name__`` is stored as the
        ``operator_name`` class attribute
    :param class_name: name of the generated class
    :param doc: docstring of the generated class
    :param conv: stored as the ``_fct_converter`` class attribute
    :param shape_calc: stored as the ``_fct_shape_calc`` class attribute
    :param alias: stored as the ``alias`` class attribute
    :return: the new class; ``input_range`` and ``output_range`` are both
        ``[1, 1e9]``, ``op_version`` is ``None`` and ``__module__`` is this
        module's name
    """
    # NOTE(review): duplicates the module-level import of the same name.
    from .onnx_subgraph_operator_mixin import OnnxSubGraphOperatorMixin

    bases = (OnnxSubGraphOperatorMixin, skl_obj)
    namespace = dict(
        __doc__=doc,
        operator_name=skl_obj.__name__,
        _fct_converter=conv,
        _fct_shape_calc=shape_calc,
        input_range=[1, 1e9],
        output_range=[1, 1e9],
        op_version=None,
        alias=alias,
        __module__=__name__,
    )
    return type(class_name, bases, namespace)
def dynamic_class_creation_sklearn():
    """
    Builds one wrapper class per entry of ``sklearn_operator_name_map``.

    For each ``(skl_obj, name)`` pair whose *skl_obj* is not ``None``, the
    converter and shape calculator stored under *name* are looked up and
    handed to :func:`ClassFactorySklearn`.

    The generated class is named ``'Onnx' + prefix + skl_obj.__name__``.
    *prefix* is ``'Sklearn'`` when ``'sklearn'`` occurs in ``str(skl_obj)``,
    and empty otherwise.

    Its docstring is a title line and a blank line, followed by the
    converter's dedented docstring when it has one.

    Entries for which :func:`ClassFactorySklearn` raises :class:`TypeError`
    are skipped.

    :return: dict mapping each generated class name to its class
    """
    from ..common._registration import _shape_calculator_pool, _converter_pool
    from .._supported_operators import sklearn_operator_name_map

    generated = {}
    for skl_obj, alias in sklearn_operator_name_map.items():
        if skl_obj is None:
            continue
        converter = _converter_pool[alias]
        shape_calculator = _shape_calculator_pool[alias]
        skl_name = skl_obj.__name__

        doc_lines = ["OnnxOperatorMixin for **{}**".format(skl_name), ""]
        if converter.__doc__:
            doc_lines.append(textwrap.dedent(converter.__doc__))
        class_doc = "\n".join(doc_lines)

        # Classes whose string form mentions 'sklearn' get 'Sklearn' in
        # their generated name.
        prefix = "Sklearn" if "sklearn" in str(skl_obj) else ""
        class_name = "Onnx" + prefix + skl_name

        try:
            new_class = ClassFactorySklearn(
                skl_obj, class_name, class_doc, converter, shape_calculator, alias
            )
        except TypeError:
            # NOTE(review): presumably type() failing to combine the mixin
            # with skl_obj; such entries are dropped silently.
            continue
        generated[class_name] = new_class
    return generated
def _update_module():
    """
    Publishes every class returned by :func:`dynamic_class_creation_sklearn`
    as an attribute of this module, using the dict key as attribute name.
    """
    generated = dynamic_class_creation_sklearn()
    module = sys.modules[__name__]
    for attr_name in generated:
        setattr(module, attr_name, generated[attr_name])
def find_class(skl_cl):
    """
    Finds the corresponding :class:`OnnxSubGraphOperatorMixin`
    class to *skl_cl*.

    The lookup is purely by name: it returns the attribute
    ``'OnnxSklearn' + skl_cl.__name__`` of this module.

    :param skl_cl: class whose ``__name__`` is used to build the lookup name
    :return: the matching attribute of this module
    :raises RuntimeError: if this module has no such attribute (the message
        lists every ``OnnxSklearn*`` name this module does define), or if
        ``str()`` of the found object contains ``'automation'``
    """
    name = skl_cl.__name__
    prefix = "OnnxSklearn"
    full_name = prefix + name
    this = sys.modules[__name__]
    if not hasattr(this, full_name):
        # Candidates are this module's own attributes; entries of
        # sys.modules are module names, not generated class names.
        available = sorted(n for n in dir(this) if n.startswith(prefix))
        raise RuntimeError(
            "Unable to find a class for '{}' in\n{}".format(
                skl_cl.__name__, "\n".join(available)
            )
        )
    cl = getattr(this, full_name)
    if "automation" in str(cl):
        # NOTE(review): rejects objects whose string form mentions
        # 'automation'; the intent is not visible from this module.
        raise RuntimeError(
            "Dynamic operation issue with class "
            "name '{}' from '{}'.".format(cl, __name__)
        )
    return cl
class OnnxSklearnPipeline(Pipeline, OnnxSubGraphOperatorMixin):
    """
    Combines `Pipeline
    <https://scikit-learn.org/stable/modules/generated/
    sklearn.pipeline.Pipeline.html>`_ and
    :class:`OnnxSubGraphOperatorMixin`.

    :param steps: forwarded to :class:`sklearn.pipeline.Pipeline`
    :param memory: forwarded to :class:`sklearn.pipeline.Pipeline`
    :param verbose: forwarded to :class:`sklearn.pipeline.Pipeline`
    :param op_version: stored as ``self.op_version``
    """

    def __init__(self, steps, memory=None, verbose=False, op_version=None):
        # Each base class is initialised explicitly.
        Pipeline.__init__(self, steps=steps, memory=memory, verbose=verbose)
        OnnxSubGraphOperatorMixin.__init__(self)
        self.op_version = op_version
class OnnxSklearnFeatureUnion(FeatureUnion, OnnxSubGraphOperatorMixin):
    """
    Combines `FeatureUnion
    <https://scikit-learn.org/stable/modules/generated/
    sklearn.pipeline.FeatureUnion.html>`_ and
    :class:`OnnxSubGraphOperatorMixin`.

    :param op_version: stored as ``self.op_version``; kept as the first
        positional parameter so existing callers keep working
    :param transformer_list: keyword-only; when not ``None`` it is forwarded,
        together with *n_jobs*, *transformer_weights* and *verbose*, to
        :class:`sklearn.pipeline.FeatureUnion`
    :param n_jobs: keyword-only, see *transformer_list*
    :param transformer_weights: keyword-only, see *transformer_list*
    :param verbose: keyword-only, see *transformer_list*
    """

    def __init__(
        self,
        op_version=None,
        *,
        transformer_list=None,
        n_jobs=None,
        transformer_weights=None,
        verbose=False,
    ):
        # Without a transformer list the FeatureUnion base is left
        # uninitialised, as before. NOTE(review): some scikit-learn releases
        # validate transformer_list eagerly in __init__, so None is not
        # forwarded; confirm against the supported versions.
        if transformer_list is not None:
            FeatureUnion.__init__(
                self,
                transformer_list,
                n_jobs=n_jobs,
                transformer_weights=transformer_weights,
                verbose=verbose,
            )
        OnnxSubGraphOperatorMixin.__init__(self)
        self.op_version = op_version
# At import time, attach every generated "Onnx*" wrapper class to this module.
_update_module()