.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorial/plot_dbegin_options.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorial_plot_dbegin_options.py:


One model, many possible conversions with options
=================================================

.. index:: options

There is not a single way to convert a model. A new operator may have been
added in a newer version of :epkg:`ONNX` that speeds up the converted model.
The rational choice would be to use this new operator, but that requires the
runtime in use to implement it. And what if two different users need two
different conversions of the same model? Let's see how this can be done.

Option *zipmap*
+++++++++++++++

Every classifier is, by design, converted into an ONNX graph which outputs
two results: the predicted label and the predicted probabilities for every
label. In this example, the labels are integers, and by default the
probabilities are stored in dictionaries, one per sample. That is the purpose
of operator *ZipMap*, added at the end of the following graph.

.. runpython::

    import numpy
    from onnx.helper import printable_graph
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import to_onnx

    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, _, y_train, __ = train_test_split(X, y, random_state=11)
    clr = LogisticRegression(max_iter=1000)
    clr.fit(X_train, y_train)

    model_def = to_onnx(clr, X_train.astype(numpy.float32))
    print(printable_graph(model_def.graph))

This operator is not very efficient: it copies every probability and label
into a different container. For small classifiers, this copy usually takes a
significant share of the prediction time, so it makes sense to remove it.

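
To make this cost more concrete, the following pure-Python sketch mimics the
layout of the *ZipMap* output: the probability matrix becomes a list with one
dictionary per sample, mapping every label to its probability. The
probabilities are made-up values and ``zipmap_like`` is a hypothetical helper
written for this illustration; it is not part of :epkg:`sklearn-onnx` and it is
not how a runtime implements the operator.

.. code-block:: python

    import numpy

    # Made-up probabilities for 2 samples and 3 classes (not produced by a model).
    labels = [0, 1, 2]
    probabilities = numpy.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])


    def zipmap_like(labels, probabilities):
        """Hypothetical helper mimicking the output layout of ZipMap:
        every row of the matrix becomes a new dictionary {label: probability}."""
        # Every value is copied out of the contiguous matrix into a Python object.
        return [dict(zip(labels, row.tolist())) for row in probabilities]


    as_dicts = zipmap_like(labels, probabilities)
    print(as_dicts)
    # [{0: 0.7, 1: 0.2, 2: 0.1}, {0: 0.1, 1: 0.3, 2: 0.6}]

Every probability ends up in a separate dictionary entry, whereas without
*ZipMap* the probabilities stay in one contiguous matrix.
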

.. runpython::

    import numpy
    from onnx.helper import printable_graph
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx import to_onnx

    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, _, y_train, __ = train_test_split(X, y, random_state=11)
    clr = LogisticRegression(max_iter=1000)
    clr.fit(X_train, y_train)

    model_def = to_onnx(clr, X_train.astype(numpy.float32),
                        options={LogisticRegression: {'zipmap': False}})
    print(printable_graph(model_def.graph))

A graph may contain many classifiers, so there must be a way to specify which
classifier keeps its *ZipMap* and which one does not. This is why options can
also be specified by id.

.. GENERATED FROM PYTHON SOURCE LINES 70-97

.. code-block:: default


    from pprint import pformat
    import numpy
    from onnx.reference import ReferenceEvaluator
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.pipeline import Pipeline
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from skl2onnx.common._registration import _converter_pool
    from skl2onnx import to_onnx
    from onnxruntime import InferenceSession

    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, _ = train_test_split(X, y, random_state=11)
    clr = LogisticRegression()
    clr.fit(X_train, y_train)

    model_def = to_onnx(
        clr, X_train.astype(numpy.float32), options={id(clr): {"zipmap": False}}
    )
    oinf = ReferenceEvaluator(model_def)
    print(oinf)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/xadupre/github/scikit-learn/sklearn/linear_model/_logistic.py:472: ConvergenceWarning: lbfgs failed to converge (status=1):
    STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

    Increase the number of iterations (max_iter) or scale the data as shown in:
        https://scikit-learn.org/stable/modules/preprocessing.html
    Please also refer to the documentation for alternative solver options:
        https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
      n_iter_i = _check_optimize_result(
    ReferenceEvaluator(X) -> label, probabilities


.. GENERATED FROM PYTHON SOURCE LINES 98-100

Using function *id* has one flaw: an id only identifies an object within the
running process and is usually different once the model is pickled and loaded
again. It is better to use strings.

.. GENERATED FROM PYTHON SOURCE LINES 100-106

.. code-block:: default


    model_def = to_onnx(clr, X_train.astype(numpy.float32), options={"zipmap": False})
    oinf = ReferenceEvaluator(model_def)
    print(oinf)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ReferenceEvaluator(X) -> label, probabilities


.. GENERATED FROM PYTHON SOURCE LINES 107-112

Option in a pipeline
++++++++++++++++++++

In a pipeline, :epkg:`sklearn-onnx` uses the same naming convention as
*scikit-learn*: the option name is prefixed by the step name and two
underscores, so ``clr__zipmap`` refers to option *zipmap* of step ``clr``.

.. GENERATED FROM PYTHON SOURCE LINES 112-121

.. code-block:: default


    pipe = Pipeline([("norm", MinMaxScaler()), ("clr", LogisticRegression())])
    pipe.fit(X_train, y_train)

    model_def = to_onnx(pipe, X_train.astype(numpy.float32), options={"clr__zipmap": False})
    oinf = ReferenceEvaluator(model_def)
    print(oinf)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ReferenceEvaluator(X) -> label, probabilities


.. GENERATED FROM PYTHON SOURCE LINES 122-129

Option *raw_scores*
+++++++++++++++++++

Every classifier is converted into a graph which returns probabilities by
default. But many models compute unscaled *raw scores* before turning them
into probabilities, and option *raw_scores* returns these scores instead.

First, with probabilities:

.. GENERATED FROM PYTHON SOURCE LINES 129-142

.. code-block:: default


    pipe = Pipeline([("norm", MinMaxScaler()), ("clr", LogisticRegression())])
    pipe.fit(X_train, y_train)

    model_def = to_onnx(
        pipe, X_train.astype(numpy.float32), options={id(pipe): {"zipmap": False}}
    )
    oinf = ReferenceEvaluator(model_def)
    print(oinf.run(None, {"X": X.astype(numpy.float32)[:5]}))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [array([0, 0, 0, 0, 0]), array([[0.88268626, 0.10948393, 0.00782984],
           [0.7944385 , 0.19728662, 0.00827491],
           [0.85557765, 0.13792053, 0.00650185],
           [0.8262804 , 0.16634221, 0.00737737],
           [0.90050155, 0.092388  , 0.00711049]], dtype=float32)]


.. GENERATED FROM PYTHON SOURCE LINES 143-144

Then with raw scores:

.. GENERATED FROM PYTHON SOURCE LINES 144-154

.. code-block:: default


    model_def = to_onnx(
        pipe,
        X_train.astype(numpy.float32),
        options={id(pipe): {"raw_scores": True, "zipmap": False}},
    )
    oinf = ReferenceEvaluator(model_def)
    print(oinf.run(None, {"X": X.astype(numpy.float32)[:5]}))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [array([0, 0, 0, 0, 0]), array([[0.88268626, 0.10948393, 0.00782984],
           [0.7944385 , 0.19728662, 0.00827491],
           [0.85557765, 0.13792053, 0.00650185],
           [0.8262804 , 0.16634221, 0.00737737],
           [0.90050155, 0.092388  , 0.00711049]], dtype=float32)]


.. GENERATED FROM PYTHON SOURCE LINES 155-158

It did not seem to work: the output still contains probabilities. The option
must be attached to a specific step of the pipeline, the classifier, and not
to the whole pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 158-168

.. code-block:: default


    model_def = to_onnx(
        pipe,
        X_train.astype(numpy.float32),
        options={id(pipe.steps[1][1]): {"raw_scores": True, "zipmap": False}},
    )
    oinf = ReferenceEvaluator(model_def)
    print(oinf.run(None, {"X": X.astype(numpy.float32)[:5]}))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [array([0, 0, 0, 0, 0]), array([[ 2.2707398 ,  0.18354762, -2.4542873 ],
           [ 1.9857951 ,  0.5928172 , -2.5786123 ],
           [ 2.2349296 ,  0.4098304 , -2.6447601 ],
           [ 2.1071343 ,  0.5042473 , -2.6113818 ],
           [ 2.3727787 ,  0.095824  , -2.4686027 ]], dtype=float32)]


.. GENERATED FROM PYTHON SOURCE LINES 169-171

There are negative values: the classifier now returns raw scores, so the
option works. Strings are still easier to use.

.. GENERATED FROM PYTHON SOURCE LINES 171-182

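
Before moving on to strings, it is worth checking how the raw scores relate to
the probabilities printed earlier. Assuming this *LogisticRegression* uses the
multinomial formulation, the softmax of every row of raw scores should give
back the probabilities. The sketch below checks this on the five rows copied
from the two outputs above; ``softmax`` is a helper written for this example,
not a function of :epkg:`sklearn-onnx`.

.. code-block:: python

    import numpy

    # Values copied from the outputs above (first five iris samples, same pipeline).
    raw_scores = numpy.array(
        [
            [2.2707398, 0.18354762, -2.4542873],
            [1.9857951, 0.5928172, -2.5786123],
            [2.2349296, 0.4098304, -2.6447601],
            [2.1071343, 0.5042473, -2.6113818],
            [2.3727787, 0.095824, -2.4686027],
        ]
    )
    probabilities = numpy.array(
        [
            [0.88268626, 0.10948393, 0.00782984],
            [0.7944385, 0.19728662, 0.00827491],
            [0.85557765, 0.13792053, 0.00650185],
            [0.8262804, 0.16634221, 0.00737737],
            [0.90050155, 0.092388, 0.00711049],
        ]
    )


    def softmax(scores):
        # Assumption: multinomial model, probabilities = softmax(raw scores).
        # Subtracting the row maximum avoids overflow in numpy.exp.
        exp = numpy.exp(scores - scores.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)


    print(numpy.allclose(softmax(raw_scores), probabilities, atol=1e-5))
    # True
    print(raw_scores.argmax(axis=1))
    # [0 0 0 0 0]

Softmax keeps the order of the values within a row, so the predicted label is
the same whether it is computed from the raw scores or from the probabilities.
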
.. code-block:: default


    model_def = to_onnx(
        pipe,
        X_train.astype(numpy.float32),
        options={"clr__raw_scores": True, "clr__zipmap": False},
    )
    oinf = ReferenceEvaluator(model_def)
    print(oinf.run(None, {"X": X.astype(numpy.float32)[:5]}))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [array([0, 0, 0, 0, 0]), array([[ 2.2707398 ,  0.18354762, -2.4542873 ],
           [ 1.9857951 ,  0.5928172 , -2.5786123 ],
           [ 2.2349296 ,  0.4098304 , -2.6447601 ],
           [ 2.1071343 ,  0.5042473 , -2.6113818 ],
           [ 2.3727787 ,  0.095824  , -2.4686027 ]], dtype=float32)]


.. GENERATED FROM PYTHON SOURCE LINES 183-184

Negative values again: the model still returns raw scores.

.. GENERATED FROM PYTHON SOURCE LINES 186-191

Option *decision_path*
++++++++++++++++++++++

*scikit-learn* implements method ``decision_path`` to retrieve the decision
path of every sample, that is, the nodes it visits in every tree. The
converted model returns the same information when option *decision_path* is
enabled.

.. GENERATED FROM PYTHON SOURCE LINES 191-207

.. code-block:: default


    clrrf = RandomForestClassifier(n_estimators=2, max_depth=2)
    clrrf.fit(X_train, y_train)
    clrrf.predict(X_test[:2])
    paths, n_nodes_ptr = clrrf.decision_path(X_test[:2])
    print(paths.todense())

    model_def = to_onnx(
        clrrf,
        X_train.astype(numpy.float32),
        options={id(clrrf): {"decision_path": True, "zipmap": False}},
    )
    sess = InferenceSession(
        model_def.SerializeToString(), providers=["CPUExecutionProvider"]
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [[1 0 0 0 1 0 1 1 1 0 1 0 0 0]
     [1 0 0 0 1 0 1 1 1 0 1 0 0 0]]


.. GENERATED FROM PYTHON SOURCE LINES 208-209

The model produces 3 outputs.

.. GENERATED FROM PYTHON SOURCE LINES 209-212

.. code-block:: default


    print([o.name for o in sess.get_outputs()])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ['label', 'probabilities', 'decision_path']


.. GENERATED FROM PYTHON SOURCE LINES 213-214

Let's display the last one.

.. GENERATED FROM PYTHON SOURCE LINES 214-218

.. code-block:: default


    res = sess.run(None, {"X": X_test[:2].astype(numpy.float32)})
    print(res[-1])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [['1000101' '1101000']
     ['1000101' '1101000']]

Each string is the decision path of one sample in one tree, one character per
node: ``'1000101'`` for the first tree and ``'1101000'`` for the second.
Concatenated, they give the rows returned by *scikit-learn* above.

.. GENERATED FROM PYTHON SOURCE LINES 219-224

List of available options
+++++++++++++++++++++++++

Every converter registers the options it supports, so that the options given
to a conversion can be checked against them. The following code lists them.

.. GENERATED FROM PYTHON SOURCE LINES 224-237

.. code-block:: default


    all_opts = set()
    for k, v in sorted(_converter_pool.items()):
        opts = v.get_allowed_options()
        if not isinstance(opts, dict):
            continue
        name = k.replace("Sklearn", "")
        print("%s%s %r" % (name, " " * (30 - len(name)), opts))
        for o in opts:
            all_opts.add(o)

    print("all options:", pformat(list(sorted(all_opts))))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    AdaBoostClassifier             {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    BaggingClassifier              {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    BayesianGaussianMixture        {'score_samples': [True, False]}
    BayesianRidge                  {'return_std': [True, False]}
    BernoulliNB                    {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    CalibratedClassifierCV         {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    CategoricalNB                  {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    ComplementNB                   {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    CountVectorizer                {'tokenexp': None, 'separators': None, 'nan': [True, False], 'keep_empty_string': [True, False], 'locale': None}
    DecisionTreeClassifier         {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'decision_path': [True, False], 'decision_leaf': [True, False]}
    DecisionTreeRegressor          {'decision_path': [True, False], 'decision_leaf': [True, False]}
    ExtraTreeClassifier            {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'decision_path': [True, False], 'decision_leaf': [True, False]}
    ExtraTreeRegressor             {'decision_path': [True, False], 'decision_leaf': [True, False]}
    ExtraTreesClassifier           {'zipmap': [True, False, 'columns'], 'raw_scores': [True, False], 'nocl': [True, False], 'output_class_labels': [False, True], 'decision_path': [True, False], 'decision_leaf': [True, False]}
    ExtraTreesRegressor            {'decision_path': [True, False], 'decision_leaf': [True, False]}
    GaussianMixture                {'score_samples': [True, False]}
    GaussianNB                     {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    GaussianProcessClassifier      {'optim': [None, 'cdist'], 'nocl': [False, True], 'output_class_labels': [False, True], 'zipmap': [False, True]}
    GaussianProcessRegressor       {'return_cov': [False, True], 'return_std': [False, True], 'optim': [None, 'cdist']}
    GradientBoostingClassifier     {'zipmap': [True, False, 'columns'], 'raw_scores': [True, False], 'output_class_labels': [False, True], 'nocl': [True, False]}
    HistGradientBoostingClassifier {'zipmap': [True, False, 'columns'], 'raw_scores': [True, False], 'output_class_labels': [False, True], 'nocl': [True, False]}
    HistGradientBoostingRegressor  {'zipmap': [True, False, 'columns'], 'raw_scores': [True, False], 'output_class_labels': [False, True], 'nocl': [True, False]}
    IsolationForest                {'score_samples': [True, False]}
    KMeans                         {'gemm': [True, False]}
    KNNImputer                     {'optim': [None, 'cdist']}
    KNeighborsClassifier           {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'raw_scores': [True, False], 'output_class_labels': [False, True], 'optim': [None, 'cdist']}
    KNeighborsRegressor            {'optim': [None, 'cdist']}
    KNeighborsTransformer          {'optim': [None, 'cdist']}
    KernelPCA                      {'optim': [None, 'cdist']}
    LinearClassifier               {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    LinearSVC                      {'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    LocalOutlierFactor             {'score_samples': [True, False], 'optim': [None, 'cdist']}
    MLPClassifier                  {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    MaxAbsScaler                   {'div': ['std', 'div', 'div_cast']}
    MiniBatchKMeans                {'gemm': [True, False]}
    MultiOutputClassifier          {'nocl': [False, True], 'output_class_labels': [False, True], 'zipmap': [False, True]}
    MultinomialNB                  {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    NearestNeighbors               {'optim': [None, 'cdist']}
    OneVsOneClassifier             {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True]}
    OneVsRestClassifier            {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    QuadraticDiscriminantAnalysis  {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True]}
    RadiusNeighborsClassifier      {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'raw_scores': [True, False], 'output_class_labels': [False, True], 'optim': [None, 'cdist']}
    RadiusNeighborsRegressor       {'optim': [None, 'cdist']}
    RandomForestClassifier         {'zipmap': [True, False, 'columns'], 'raw_scores': [True, False], 'nocl': [True, False], 'output_class_labels': [False, True], 'decision_path': [True, False], 'decision_leaf': [True, False]}
    RandomForestRegressor          {'decision_path': [True, False], 'decision_leaf': [True, False]}
    RobustScaler                   {'div': ['std', 'div', 'div_cast']}
    SGDClassifier                  {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    SVC                            {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    Scaler                         {'div': ['std', 'div', 'div_cast']}
    StackingClassifier             {'zipmap': [True, False, 'columns'], 'nocl': [True, False], 'output_class_labels': [False, True], 'raw_scores': [True, False]}
    TfidfTransformer               {'nan': [True, False]}
    TfidfVectorizer                {'tokenexp': None, 'separators': None, 'nan': [True, False], 'keep_empty_string': [True, False], 'locale': None}
    VotingClassifier               {'zipmap': [True, False, 'columns'], 'output_class_labels': [False, True], 'nocl': [True, False]}
    _ConstantPredictor             {'zipmap': [True, False, 'columns'], 'nocl': [True, False]}
    all options: ['decision_leaf',
     'decision_path',
     'div',
     'gemm',
     'keep_empty_string',
     'locale',
     'nan',
     'nocl',
     'optim',
     'output_class_labels',
     'raw_scores',
     'return_cov',
     'return_std',
     'score_samples',
     'separators',
     'tokenexp',
     'zipmap']


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.116 seconds)


.. _sphx_glr_download_auto_tutorial_plot_dbegin_options.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_dbegin_options.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_dbegin_options.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_