.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_complex_pipeline.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_complex_pipeline.py:


.. _example-complex-pipeline:

Convert a pipeline with ColumnTransformer
==========================================

*scikit-learn* recently shipped
`ColumnTransformer `_
which lets the user define complex pipelines where each column
may be preprocessed with a different transformer.
*sklearn-onnx* still works in this case as shown in Section
:ref:`l-complex-pipeline`.

Create and train a complex pipeline
+++++++++++++++++++++++++++++++++++

We reuse the pipeline implemented in example
`Column Transformer with Mixed Types `_.
There is one change because
`ONNX-ML Imputer `_
does not handle string types. The imputation of string columns therefore
cannot be part of the final ONNX pipeline and must be done beforehand.
Look for the comment starting with ``---`` below.

.. GENERATED FROM PYTHON SOURCE LINES 32-101

.. code-block:: default


    import os
    import pprint
    import pandas as pd
    import numpy as np
    from numpy.testing import assert_almost_equal
    import onnx
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    import onnxruntime as rt
    import matplotlib.pyplot as plt
    import sklearn
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import StandardScaler, OneHotEncoder
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    import skl2onnx
    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import FloatTensorType, StringTensorType
    from skl2onnx.common.data_types import Int64TensorType

    titanic_url = (
        "https://raw.githubusercontent.com/amueller/"
        "scipy-2017-sklearn/091d371/notebooks/datasets/titanic3.csv"
    )
    data = pd.read_csv(titanic_url)
    X = data.drop("survived", axis=1)
    y = data["survived"]
    print(data.dtypes)

    # SimpleImputer on strings is not available in the
    # ONNX-ML specifications, so we do it beforehand.
    for cat in ["embarked", "sex", "pclass"]:
        X[cat].fillna("missing", inplace=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    numeric_features = ["age", "fare"]
    numeric_transformer = Pipeline(
        steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]
    )

    categorical_features = ["embarked", "sex", "pclass"]
    categorical_transformer = Pipeline(
        steps=[
            # --- SimpleImputer is not available for strings in ONNX-ML specifications.
            # ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ("onehot", OneHotEncoder(handle_unknown="ignore"))
        ]
    )

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_transformer, numeric_features),
            ("cat", categorical_transformer, categorical_features),
        ]
    )

    clf = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("classifier", LogisticRegression(solver="lbfgs")),
        ]
    )

    clf.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    pclass         int64
    survived       int64
    name          object
    sex           object
    age          float64
    sibsp          int64
    parch          int64
    ticket        object
    fare         float64
    cabin         object
    embarked      object
    boat          object
    body         float64
    home.dest     object
    dtype: object
The fitted pipeline:

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline(steps=[('preprocessor',
                     ColumnTransformer(transformers=[('num',
                                                      Pipeline(steps=[('imputer',
                                                                       SimpleImputer(strategy='median')),
                                                                      ('scaler',
                                                                       StandardScaler())]),
                                                      ['age', 'fare']),
                                                     ('cat',
                                                      Pipeline(steps=[('onehot',
                                                                       OneHotEncoder(handle_unknown='ignore'))]),
                                                      ['embarked', 'sex',
                                                       'pclass'])])),
                    ('classifier', LogisticRegression())])
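Before converting the pipeline, it can be useful to double-check the fitted
model. The snippet below is an illustrative aside and is not part of the
generated script: it lists the columns produced by the fitted
``ColumnTransformer`` (``get_feature_names_out`` is only available in recent
*scikit-learn* versions) and computes the accuracy on the held-out split.

.. code-block:: python

    # Illustrative aside, not part of the generated example.
    # Columns produced by the fitted ColumnTransformer
    # (requires a recent scikit-learn version).
    print(clf.named_steps["preprocessor"].get_feature_names_out())

    # Quick sanity check on the held-out split before conversion.
    print("accuracy:", clf.score(X_test, y_test))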
.. GENERATED FROM PYTHON SOURCE LINES 102-108

Define the inputs of the ONNX graph
+++++++++++++++++++++++++++++++++++

*sklearn-onnx* does not know which features were used to train the model,
but it needs to know each feature's name and type. We simply reuse the
dataframe column definitions.

.. GENERATED FROM PYTHON SOURCE LINES 108-110

.. code-block:: default


    print(X_train.dtypes)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    pclass         int64
    name          object
    sex           object
    age          float64
    sibsp          int64
    parch          int64
    ticket        object
    fare         float64
    cabin         object
    embarked      object
    boat          object
    body         float64
    home.dest     object
    dtype: object

.. GENERATED FROM PYTHON SOURCE LINES 111-112

The dataframe schema is converted into a list of ONNX input types as follows.

.. GENERATED FROM PYTHON SOURCE LINES 112-133

.. code-block:: default


    def convert_dataframe_schema(df, drop=None):
        inputs = []
        for k, v in zip(df.columns, df.dtypes):
            if drop is not None and k in drop:
                continue
            if v == "int64":
                t = Int64TensorType([None, 1])
            elif v == "float64":
                t = FloatTensorType([None, 1])
            else:
                t = StringTensorType([None, 1])
            inputs.append((k, t))
        return inputs


    initial_inputs = convert_dataframe_schema(X_train)
    pprint.pprint(initial_inputs)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [('pclass', Int64TensorType(shape=[None, 1])),
     ('name', StringTensorType(shape=[None, 1])),
     ('sex', StringTensorType(shape=[None, 1])),
     ('age', FloatTensorType(shape=[None, 1])),
     ('sibsp', Int64TensorType(shape=[None, 1])),
     ('parch', Int64TensorType(shape=[None, 1])),
     ('ticket', StringTensorType(shape=[None, 1])),
     ('fare', FloatTensorType(shape=[None, 1])),
     ('cabin', StringTensorType(shape=[None, 1])),
     ('embarked', StringTensorType(shape=[None, 1])),
     ('boat', StringTensorType(shape=[None, 1])),
     ('body', FloatTensorType(shape=[None, 1])),
     ('home.dest', StringTensorType(shape=[None, 1]))]

.. GENERATED FROM PYTHON SOURCE LINES 134-137

Merging single columns into vectors is not the most efficient way to
compute the prediction. It could be done before converting the pipeline
into a graph.

.. GENERATED FROM PYTHON SOURCE LINES 139-141

Convert the pipeline into ONNX
++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 141-149

.. code-block:: default


    try:
        model_onnx = convert_sklearn(
            clf, "pipeline_titanic", initial_inputs, target_opset=12
        )
    except Exception as e:
        print(e)

.. GENERATED FROM PYTHON SOURCE LINES 150-153

Predictions are more efficient if the graph is small. That's why the
converter checks that there are no unused inputs; they need to be removed
from the graph inputs.

.. GENERATED FROM PYTHON SOURCE LINES 153-163

.. code-block:: default


    to_drop = {"parch", "sibsp", "cabin", "ticket", "name", "body", "home.dest", "boat"}
    initial_inputs = convert_dataframe_schema(X_train, to_drop)

    try:
        model_onnx = convert_sklearn(
            clf, "pipeline_titanic", initial_inputs, target_opset=12
        )
    except Exception as e:
        print(e)

.. GENERATED FROM PYTHON SOURCE LINES 164-167

*scikit-learn* does implicit conversions when it can. *sklearn-onnx* does not.
The ONNX version of *OneHotEncoder* must be applied on columns of the same type.

.. GENERATED FROM PYTHON SOURCE LINES 167-177

.. code-block:: default


    initial_inputs = convert_dataframe_schema(X_train, to_drop)

    model_onnx = convert_sklearn(clf, "pipeline_titanic", initial_inputs, target_opset=12)


    # And save.
    with open("pipeline_titanic.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())
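As an optional sanity check, not part of the generated script, the inputs
actually declared by the saved model can be listed with *onnxruntime*; only
the columns kept in ``initial_inputs`` should appear. The session name
``sess_check`` below is just an illustrative choice.

.. code-block:: python

    # Illustrative check: list the inputs declared by the converted model.
    # Only the columns kept in `initial_inputs` should be present.
    sess_check = rt.InferenceSession(
        "pipeline_titanic.onnx", providers=["CPUExecutionProvider"]
    )
    for onnx_input in sess_check.get_inputs():
        print(onnx_input.name, onnx_input.type, onnx_input.shape)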
.. GENERATED FROM PYTHON SOURCE LINES 178-184

Compare the predictions
+++++++++++++++++++++++

As a final step, we need to ensure the converted model produces the same
predictions, labels and probabilities. Let's start with *scikit-learn*.

.. GENERATED FROM PYTHON SOURCE LINES 184-188

.. code-block:: default


    print("predict", clf.predict(X_test[:5]))
    print("predict_proba", clf.predict_proba(X_test[:2]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [1 0 0 1 0]
    predict_proba [[0.34887729 0.65112271]
     [0.92762608 0.07237392]]

.. GENERATED FROM PYTHON SOURCE LINES 189-198

Predictions with *onnxruntime*. We need to remove the dropped columns and to
cast the double vectors into float vectors, as *onnxruntime* does not support
double floats. *onnxruntime* does not accept *dataframes*; inputs must be
given as a dictionary of arrays. One last detail: every column was described
not as a vector but as a matrix with a single column, which explains the
final *reshape*.

.. GENERATED FROM PYTHON SOURCE LINES 198-206

.. code-block:: default


    X_test2 = X_test.drop(to_drop, axis=1)
    inputs = {c: X_test2[c].values for c in X_test2.columns}
    for c in numeric_features:
        inputs[c] = inputs[c].astype(np.float32)
    for k in inputs:
        inputs[k] = inputs[k].reshape((inputs[k].shape[0], 1))

.. GENERATED FROM PYTHON SOURCE LINES 207-208

We are ready to run *onnxruntime*.

.. GENERATED FROM PYTHON SOURCE LINES 208-214

.. code-block:: default


    sess = rt.InferenceSession("pipeline_titanic.onnx", providers=["CPUExecutionProvider"])
    pred_onx = sess.run(None, inputs)
    print("predict", pred_onx[0][:5])
    print("predict_proba", pred_onx[1][:2])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [1 0 0 1 0]
    predict_proba [{0: 0.34887731075286865, 1: 0.6511226892471313}, {0: 0.9276261329650879, 1: 0.0723738968372345}]

.. GENERATED FROM PYTHON SOURCE LINES 215-218

The probabilities returned by *onnxruntime* are a list of dictionaries.
Let's switch to an array instead; that requires converting the model again
with the additional option ``zipmap`` set to ``False``.

.. GENERATED FROM PYTHON SOURCE LINES 218-236

.. code-block:: default


    model_onnx = convert_sklearn(
        clf,
        "pipeline_titanic",
        initial_inputs,
        target_opset=12,
        options={id(clf): {"zipmap": False}},
    )
    with open("pipeline_titanic_nozipmap.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())

    sess = rt.InferenceSession(
        "pipeline_titanic_nozipmap.onnx", providers=["CPUExecutionProvider"]
    )
    pred_onx = sess.run(None, inputs)
    print("predict", pred_onx[0][:5])
    print("predict_proba", pred_onx[1][:2])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    predict [1 0 0 1 0]
    predict_proba [[0.3488773  0.6511227 ]
     [0.92762613 0.0723739 ]]

.. GENERATED FROM PYTHON SOURCE LINES 237-238

Let's check they are the same.

.. GENERATED FROM PYTHON SOURCE LINES 238-240

.. code-block:: default


    assert_almost_equal(clf.predict_proba(X_test), pred_onx[1])
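The predicted labels can be compared the same way. The snippet below is an
illustrative aside, not part of the generated script; it should hold as long
as no sample sits exactly on the decision boundary, and it also reports the
largest probability difference, which is expected to be tiny.

.. code-block:: python

    # Illustrative aside: compare labels and report the largest
    # probability difference between scikit-learn and onnxruntime.
    assert_almost_equal(clf.predict(X_test), pred_onx[0])
    print("max difference:", np.abs(clf.predict_proba(X_test) - pred_onx[1]).max())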
.. GENERATED FROM PYTHON SOURCE LINES 241-247

.. _l-plot-complex-pipeline-graph:

Display the ONNX graph
++++++++++++++++++++++

Finally, let's see the graph converted with *sklearn-onnx*.

.. GENERATED FROM PYTHON SOURCE LINES 247-265

.. code-block:: default


    pydot_graph = GetPydotGraph(
        model_onnx.graph,
        name=model_onnx.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"
        ),
    )
    pydot_graph.write_dot("pipeline_titanic.dot")

    os.system("dot -O -Gdpi=300 -Tpng pipeline_titanic.dot")

    image = plt.imread("pipeline_titanic.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis("off")

.. image-sg:: /auto_examples/images/sphx_glr_plot_complex_pipeline_001.png
   :alt: plot complex pipeline
   :srcset: /auto_examples/images/sphx_glr_plot_complex_pipeline_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (-0.5, 6901.5, 6049.5, -0.5)

.. GENERATED FROM PYTHON SOURCE LINES 266-267

**Versions used for this example**

.. GENERATED FROM PYTHON SOURCE LINES 267-273

.. code-block:: default


    print("numpy:", np.__version__)
    print("scikit-learn:", sklearn.__version__)
    print("onnx: ", onnx.__version__)
    print("onnxruntime: ", rt.__version__)
    print("skl2onnx: ", skl2onnx.__version__)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    numpy: 1.23.5
    scikit-learn: 1.4.dev0
    onnx:  1.15.0
    onnxruntime:  1.16.0+cu118
    skl2onnx:  1.15.0

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 8.095 seconds)


.. _sphx_glr_download_auto_examples_plot_complex_pipeline.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_complex_pipeline.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_complex_pipeline.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_