.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_custom_parser_alternative.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_custom_parser_alternative.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_custom_parser_alternative.py:


.. _l-custom-parser-alternative:

When a custom model is neither a classifier nor a regressor (alternative)
=========================================================================

.. note::
    This example rewrites :ref:`l-custom-parser` by using the syntax
    proposed in example :ref:`l-onnx-operators` to write the custom
    converter, shape calculator and parser.

*scikit-learn*'s API specifies that a regressor produces one
output and a classifier produces two outputs, predicted labels
and probabilities. The goal here is to add a third result which tells
whether the probability is above a given threshold. That's implemented
in method *validate*.

Iris and scoring
++++++++++++++++

A new class is created; it trains any classifier and implements
the method *validate* mentioned above.

.. GENERATED FROM PYTHON SOURCE LINES 28-93

.. code-block:: Python


    import inspect
    import numpy as np
    import skl2onnx
    import onnx
    import sklearn
    from sklearn.base import ClassifierMixin, BaseEstimator, clone
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from skl2onnx import update_registered_converter
    import os
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    import onnxruntime as rt
    from skl2onnx import to_onnx, get_model_alias
    from skl2onnx.proto import onnx_proto
    from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
    from skl2onnx.algebra.onnx_ops import (
        OnnxGreater,
        OnnxCast,
        OnnxReduceMaxApi18,
        OnnxIdentity,
    )
    from skl2onnx.algebra.onnx_operator import OnnxSubEstimator
    import matplotlib.pyplot as plt


    class ValidatorClassifier(BaseEstimator, ClassifierMixin):
        def __init__(self, estimator=None, threshold=0.75):
            ClassifierMixin.__init__(self)
            BaseEstimator.__init__(self)
            if estimator is None:
                estimator = LogisticRegression(solver="liblinear")
            self.estimator = estimator
            self.threshold = threshold

        def fit(self, X, y, sample_weight=None):
            # Forward sample_weight only if the wrapped estimator accepts it.
            sig = inspect.signature(self.estimator.fit)
            if "sample_weight" in sig.parameters:
                self.estimator_ = clone(self.estimator).fit(
                    X, y, sample_weight=sample_weight
                )
            else:
                self.estimator_ = clone(self.estimator).fit(X, y)
            return self

        def predict(self, X):
            return self.estimator_.predict(X)

        def predict_proba(self, X):
            return self.estimator_.predict_proba(X)

        def validate(self, X):
            # 1 if the highest class probability reaches the threshold, else 0.
            pred = self.predict_proba(X)
            mx = pred.max(axis=1)
            return (mx >= self.threshold) * 1


    data = load_iris()
    X, y = data.data, data.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    model = ValidatorClassifier()
    model.fit(X_train, y_train)


.. raw:: html

    ValidatorClassifier(estimator=LogisticRegression(solver='liblinear'))


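As a rough illustration of what *validate* computes, the sketch below applies
the same rule to a small, made-up probability matrix. The values in ``proba``
are hypothetical and are not produced by the model fitted above; only the
default threshold of 0.75 is taken from the class definition.

```python
import numpy as np

# Hypothetical class probabilities for three observations and three classes
# (illustrative values only, not the output of the fitted model).
proba = np.array(
    [
        [0.10, 0.80, 0.10],
        [0.40, 0.35, 0.25],
        [0.05, 0.20, 0.75],
    ]
)
threshold = 0.75  # default threshold of ValidatorClassifier

# Same rule as ValidatorClassifier.validate: 1 when the highest class
# probability reaches the threshold, 0 otherwise.
flags = (proba.max(axis=1) >= threshold) * 1
print(flags)
```

The first and third rows reach the threshold, the second does not, so the
indicator is one value per observation rather than one per class.
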
.. GENERATED FROM PYTHON SOURCE LINES 94-97

Let's now measure the indicator which tells
whether the probability of a prediction is above
a threshold.

.. GENERATED FROM PYTHON SOURCE LINES 97-100

.. code-block:: Python


    print(model.validate(X_test))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0 1 1 0 0 1 1 0 1 1 1 0 0 1 1 1 1 1 0 1 0 0 1 0 0 0 1 1 1 1 1 0 1 1 1 0 0 0]


.. GENERATED FROM PYTHON SOURCE LINES 101-107

Conversion to ONNX
+++++++++++++++++++

The conversion fails for a new model because
the library does not know any converter associated
with this new model.

.. GENERATED FROM PYTHON SOURCE LINES 107-113

.. code-block:: Python


    try:
        to_onnx(model, X_train[:1].astype(np.float32), target_opset=12)
    except RuntimeError as e:
        print(e)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Unable to find a shape calculator for type ''.
    It usually means the pipeline being converted contains a
    transformer or a predictor with no corresponding converter
    implemented in sklearn-onnx. If the converted is implemented
    in another library, you need to register
    the converted so that it can be used by sklearn-onnx (function
    update_registered_converter). If the model is not yet covered
    by sklearn-onnx, you may raise an issue to
    https://github.com/onnx/sklearn-onnx/issues
    to get the converter implemented or even contribute to the
    project. If the model is a custom model, a new converter must
    be implemented. Examples can be found in the gallery.


.. GENERATED FROM PYTHON SOURCE LINES 114-120

Custom converter
++++++++++++++++

We reuse some pieces of code from :ref:`l-custom-model`.
The shape calculator defines the shape of every output
of the converted model.

.. GENERATED FROM PYTHON SOURCE LINES 120-137

.. code-block:: Python


    def validator_classifier_shape_calculator(operator):
        input0 = operator.inputs[0]  # first input in ONNX graph
        outputs = operator.outputs  # outputs in ONNX graph
        op = operator.raw_operator  # scikit-learn model (must be fitted)
        if len(outputs) != 3:
            raise RuntimeError("3 outputs expected not {}.".format(len(outputs)))

        N = input0.type.shape[0]  # number of observations
        C = op.estimator_.classes_.shape[0]  # number of classes

        outputs[0].type = Int64TensorType([N])  # label
        outputs[1].type = FloatTensorType([N, C])  # probabilities
        outputs[2].type = Int64TensorType([N])  # validation, one flag per observation


.. GENERATED FROM PYTHON SOURCE LINES 138-139

Then the converter.

.. GENERATED FROM PYTHON SOURCE LINES 139-167

.. code-block:: Python


    def validator_classifier_converter(scope, operator, container):
        input0 = operator.inputs[0]  # first input in ONNX graph
        outputs = operator.outputs  # outputs in ONNX graph
        op = operator.raw_operator  # scikit-learn model (must be fitted)
        opv = container.target_opset

        # The model calls another one. The class `OnnxSubEstimator`
        # calls the converter for this operator.
        model = op.estimator_
        onnx_op = OnnxSubEstimator(model, input0, op_version=opv, options={"zipmap": False})

        # Highest probability per observation, compared with the threshold.
        # Note: Greater is strict whereas *validate* uses >=, so both only
        # differ when the maximum equals the threshold exactly.
        rmax = OnnxReduceMaxApi18(onnx_op[1], axes=[1], keepdims=0, op_version=opv)
        great = OnnxGreater(
            rmax, np.array([op.threshold], dtype=np.float32), op_version=opv
        )
        valid = OnnxCast(great, to=onnx_proto.TensorProto.INT64, op_version=opv)

        r1 = OnnxIdentity(onnx_op[0], output_names=[outputs[0].full_name], op_version=opv)
        r2 = OnnxIdentity(onnx_op[1], output_names=[outputs[1].full_name], op_version=opv)
        r3 = OnnxIdentity(valid, output_names=[outputs[2].full_name], op_version=opv)

        r1.add_to(scope, container)
        r2.add_to(scope, container)
        r3.add_to(scope, container)


.. GENERATED FROM PYTHON SOURCE LINES 168-169

Then the registration.

.. GENERATED FROM PYTHON SOURCE LINES 169-178

.. code-block:: Python


    update_registered_converter(
        ValidatorClassifier,
        "CustomValidatorClassifier",
        validator_classifier_shape_calculator,
        validator_classifier_converter,
    )


.. GENERATED FROM PYTHON SOURCE LINES 179-180

And conversion...

.. GENERATED FROM PYTHON SOURCE LINES 180-186

.. code-block:: Python


    try:
        to_onnx(model, X_test[:1].astype(np.float32), target_opset=12)
    except RuntimeError as e:
        print(e)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    3 outputs expected not 2.


.. GENERATED FROM PYTHON SOURCE LINES 187-194

It fails because the library expected the model
to behave like a classifier which produces two
outputs. We need to add a custom parser to
tell the library this model produces three outputs.

Custom parser
+++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 194-215

.. code-block:: Python


    def validator_classifier_parser(scope, model, inputs, custom_parsers=None):
        alias = get_model_alias(type(model))
        this_operator = scope.declare_local_operator(alias, model)

        # inputs
        this_operator.inputs.append(inputs[0])

        # outputs
        val_label = scope.declare_local_variable("val_label", Int64TensorType())
        val_prob = scope.declare_local_variable("val_prob", FloatTensorType())
        val_val = scope.declare_local_variable("val_val", Int64TensorType())
        this_operator.outputs.append(val_label)
        this_operator.outputs.append(val_prob)
        this_operator.outputs.append(val_val)

        # ends
        return this_operator.outputs


.. GENERATED FROM PYTHON SOURCE LINES 216-217

Registration.

.. GENERATED FROM PYTHON SOURCE LINES 217-227

.. code-block:: Python


    update_registered_converter(
        ValidatorClassifier,
        "CustomValidatorClassifier",
        validator_classifier_shape_calculator,
        validator_classifier_converter,
        parser=validator_classifier_parser,
    )


.. GENERATED FROM PYTHON SOURCE LINES 228-229

And conversion again.

.. GENERATED FROM PYTHON SOURCE LINES 229-232

.. code-block:: Python


    model_onnx = to_onnx(model, X_test[:1].astype(np.float32), target_opset=12)


.. GENERATED FROM PYTHON SOURCE LINES 233-237

Final test
++++++++++

We now need to check that ONNX produces the same results.

.. GENERATED FROM PYTHON SOURCE LINES 237-255

.. code-block:: Python


    X32 = X_test[:5].astype(np.float32)

    sess = rt.InferenceSession(
        model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    results = sess.run(None, {"X": X32})

    print("--labels--")
    print("sklearn", model.predict(X32))
    print("onnx", results[0])
    print("--probabilities--")
    print("sklearn", model.predict_proba(X32))
    print("onnx", results[1])
    print("--validation--")
    print("sklearn", model.validate(X32))
    print("onnx", results[2])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    --labels--
    sklearn [2 1 0 2 2]
    onnx [2 1 0 2 2]
    --probabilities--
    sklearn [[4.35572853e-04 2.50370783e-01 7.49193644e-01]
     [6.40021595e-02 7.74525152e-01 1.61472688e-01]
     [8.72966069e-01 1.27016600e-01 1.73305757e-05]
     [2.88656526e-03 4.13689781e-01 5.83423654e-01]
     [6.84848807e-04 4.18572039e-01 5.80743112e-01]]
    onnx [[4.3557768e-04 2.5037074e-01 7.4919367e-01]
     [6.4002089e-02 7.7452540e-01 1.6147259e-01]
     [8.7296611e-01 1.2701656e-01 1.7331236e-05]
     [2.8865638e-03 4.1368976e-01 5.8342361e-01]
     [6.8487308e-04 4.1857198e-01 5.8074319e-01]]
    --validation--
    sklearn [0 1 1 0 0]
    onnx [0 1 1 0 0]


.. GENERATED FROM PYTHON SOURCE LINES 256-260

It looks good.

Display the ONNX graph
++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 260-278

.. code-block:: Python


    pydot_graph = GetPydotGraph(
        model_onnx.graph,
        name=model_onnx.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"
        ),
    )
    pydot_graph.write_dot("validator_classifier.dot")

    # Requires the Graphviz `dot` executable to be available on the PATH.
    os.system("dot -O -Gdpi=300 -Tpng validator_classifier.dot")

    image = plt.imread("validator_classifier.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis("off")


.. image-sg:: /auto_examples/images/sphx_glr_plot_custom_parser_alternative_001.png
   :alt: plot custom parser alternative
   :srcset: /auto_examples/images/sphx_glr_plot_custom_parser_alternative_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (np.float64(-0.5), np.float64(3557.5), np.float64(4934.5), np.float64(-0.5))


.. GENERATED FROM PYTHON SOURCE LINES 279-280

**Versions used for this example**

.. GENERATED FROM PYTHON SOURCE LINES 280-286

.. code-block:: Python


    print("numpy:", np.__version__)
    print("scikit-learn:", sklearn.__version__)
    print("onnx: ", onnx.__version__)
    print("onnxruntime: ", rt.__version__)
    print("skl2onnx: ", skl2onnx.__version__)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    numpy: 2.2.0
    scikit-learn: 1.6.0
    onnx: 1.18.0
    onnxruntime: 1.21.0+cu126
    skl2onnx: 1.18.0


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.343 seconds)


.. _sphx_glr_download_auto_examples_plot_custom_parser_alternative.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_custom_parser_alternative.ipynb <plot_custom_parser_alternative.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_custom_parser_alternative.py <plot_custom_parser_alternative.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_custom_parser_alternative.zip <plot_custom_parser_alternative.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_