.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_intermediate_outputs.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_intermediate_outputs.py:


Walk through intermediate outputs
=================================

We reuse the example :ref:`example-complex-pipeline` and
walk through its intermediate outputs. When a converted model
produces different outputs or fails, the cause is very often a custom
converter that is not correctly implemented. One way to find it is to
look at the output of every node of the ONNX graph.

Create and train a complex pipeline
+++++++++++++++++++++++++++++++++++

We reuse the pipeline implemented in the example
`Column Transformer with Mixed Types `_.
There is one change: the `ONNX-ML Imputer `_
does not handle the string type, so the string imputer cannot be part
of the final ONNX pipeline and must be removed.
Look for the comment starting with ``---`` below.

.. GENERATED FROM PYTHON SOURCE LINES 28-102
.. code-block:: default


    import skl2onnx
    import onnx
    import sklearn
    import matplotlib.pyplot as plt
    import os
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    from skl2onnx.helpers.onnx_helper import select_model_inputs_outputs
    from skl2onnx.helpers.onnx_helper import save_onnx_model
    from skl2onnx.helpers.onnx_helper import enumerate_model_node_outputs
    from skl2onnx.helpers.onnx_helper import load_onnx_model
    import numpy
    import onnxruntime as rt
    from skl2onnx import convert_sklearn
    import pprint
    from skl2onnx.common.data_types import (
        FloatTensorType,
        StringTensorType,
        Int64TensorType,
    )
    import numpy as np
    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import StandardScaler, OneHotEncoder
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    titanic_url = (
        "https://raw.githubusercontent.com/amueller/"
        "scipy-2017-sklearn/091d371/notebooks/datasets/titanic3.csv"
    )
    data = pd.read_csv(titanic_url)
    X = data.drop("survived", axis=1)
    y = data["survived"]

    # SimpleImputer on strings is not available in the ONNX-ML
    # specifications, so missing strings are filled beforehand.
    # Assigning the result avoids chained inplace operations,
    # which recent pandas versions no longer apply to the dataframe.
    for cat in ["embarked", "sex", "pclass"]:
        X[cat] = X[cat].fillna("missing")

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    numeric_features = ["age", "fare"]
    numeric_transformer = Pipeline(
        steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]
    )

    categorical_features = ["embarked", "sex", "pclass"]
    categorical_transformer = Pipeline(
        steps=[
            # --- SimpleImputer is not available for strings in ONNX-ML specifications.
            # ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ("onehot", OneHotEncoder(handle_unknown="ignore"))
        ]
    )

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_transformer, numeric_features),
            ("cat", categorical_transformer, categorical_features),
        ]
    )

    clf = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("classifier", LogisticRegression(solver="lbfgs")),
        ]
    )


    clf.fit(X_train, y_train)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Pipeline(steps=[('preprocessor',
                     ColumnTransformer(transformers=[('num',
                                                      Pipeline(steps=[('imputer',
                                                                       SimpleImputer(strategy='median')),
                                                                      ('scaler',
                                                                       StandardScaler())]),
                                                      ['age', 'fare']),
                                                     ('cat',
                                                      Pipeline(steps=[('onehot',
                                                                       OneHotEncoder(handle_unknown='ignore'))]),
                                                      ['embarked', 'sex',
                                                       'pclass'])])),
                    ('classifier', LogisticRegression())])

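
The widths produced by this preprocessor are what the ONNX graph later has to
reproduce: two scaled numerical columns plus one column per category of
*embarked*, *sex* and *pclass*. The following sketch rebuilds a preprocessor
with the same structure on a tiny hand-made dataframe (made-up values, not the
Titanic data) and counts those columns. The one-hot encoder is used directly
instead of inside a one-step pipeline, which should not change its output.

.. code-block:: python

    import numpy as np
    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.impute import SimpleImputer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OneHotEncoder, StandardScaler

    # Made-up rows: 4 ports (C, Q, S, missing), 2 sexes and 3 classes.
    toy = pd.DataFrame(
        {
            "age": [22.0, np.nan, 35.0, 54.0, 2.0, 27.0],
            "fare": [7.25, 71.28, 8.05, 51.86, 21.07, 11.13],
            "embarked": ["S", "C", "Q", "missing", "S", "C"],
            "sex": ["male", "female", "female", "male", "male", "female"],
            "pclass": ["3", "1", "3", "1", "2", "2"],
        }
    )

    # Same structure as the preprocessor defined above.
    toy_numeric = Pipeline(
        steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]
    )
    toy_preprocessor = ColumnTransformer(
        transformers=[
            ("num", toy_numeric, ["age", "fare"]),
            ("cat", OneHotEncoder(handle_unknown="ignore"), ["embarked", "sex", "pclass"]),
        ]
    )
    features = toy_preprocessor.fit_transform(toy)

    # One one-hot column per category seen during fit.
    encoder = toy_preprocessor.named_transformers_["cat"]
    n_one_hot = sum(len(c) for c in encoder.categories_)
    print("one-hot columns:", n_one_hot)
    print("total columns:", features.shape[1])

With four ports (including ``missing``), two sexes and three classes, the
one-hot part has 4 + 2 + 3 = 9 columns and the whole output has 11. The same
counts appear in the ONNX model printed later in this example: a *Reshape* to
9 columns and 2 x 11 classifier coefficients.
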
.. GENERATED FROM PYTHON SOURCE LINES 103-109

Define the inputs of the ONNX graph
+++++++++++++++++++++++++++++++++++

*sklearn-onnx* does not know which features were used to train the model,
but it needs to know the name and the type of every input.
We simply reuse the dataframe column definitions.

.. GENERATED FROM PYTHON SOURCE LINES 109-111

.. code-block:: default

    print(X_train.dtypes)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    pclass         int64
    name          object
    sex           object
    age          float64
    sibsp          int64
    parch          int64
    ticket        object
    fare         float64
    cabin         object
    embarked      object
    boat          object
    body         float64
    home.dest     object
    dtype: object


.. GENERATED FROM PYTHON SOURCE LINES 112-113

Each dataframe type is then mapped to an ONNX tensor type, and every
column becomes a separate input of shape ``[None, 1]``.

.. GENERATED FROM PYTHON SOURCE LINES 113-134

.. code-block:: default


    def convert_dataframe_schema(df, drop=None):
        # Map every dataframe column to an ONNX input (name, type).
        inputs = []
        for k, v in zip(df.columns, df.dtypes):
            if drop is not None and k in drop:
                continue
            if v == "int64":
                t = Int64TensorType([None, 1])
            elif v == "float64":
                # Double columns are declared as float32 inputs.
                t = FloatTensorType([None, 1])
            else:
                # Every other column (object dtype) is declared as strings.
                t = StringTensorType([None, 1])
            inputs.append((k, t))
        return inputs


    inputs = convert_dataframe_schema(X_train)

    pprint.pprint(inputs)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [('pclass', Int64TensorType(shape=[None, 1])),
     ('name', StringTensorType(shape=[None, 1])),
     ('sex', StringTensorType(shape=[None, 1])),
     ('age', FloatTensorType(shape=[None, 1])),
     ('sibsp', Int64TensorType(shape=[None, 1])),
     ('parch', Int64TensorType(shape=[None, 1])),
     ('ticket', StringTensorType(shape=[None, 1])),
     ('fare', FloatTensorType(shape=[None, 1])),
     ('cabin', StringTensorType(shape=[None, 1])),
     ('embarked', StringTensorType(shape=[None, 1])),
     ('boat', StringTensorType(shape=[None, 1])),
     ('body', FloatTensorType(shape=[None, 1])),
     ('home.dest', StringTensorType(shape=[None, 1]))]


.. GENERATED FROM PYTHON SOURCE LINES 135-138

Merging single columns into vectors inside the graph is not the most
efficient way to compute the prediction. It could be done before
converting the pipeline into a graph.
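
A minimal sketch of merging on the caller's side, assuming a hypothetical
variant of the model with a single float input named ``num_input`` that
receives *age* and *fare* together as one ``[N, 2]`` float32 matrix instead of
two ``[N, 1]`` columns. The pipeline would then have to select its numerical
columns by position, which is not shown here. The values are made up.

.. code-block:: python

    import numpy as np
    import pandas as pd

    # Made-up rows; only the numerical columns matter here.
    df = pd.DataFrame({"age": [22.0, 38.0, np.nan], "fare": [7.25, 71.28, 8.05]})
    numeric_features = ["age", "fare"]

    # Current layout: one float32 [N, 1] input per column.
    per_column = {
        c: df[c].to_numpy(dtype=np.float32).reshape(-1, 1) for c in numeric_features
    }

    # Hypothetical layout: a single float32 [N, 2] input named "num_input".
    merged = {"num_input": df[numeric_features].to_numpy(dtype=np.float32)}

    print({name: value.shape for name, value in per_column.items()})
    print(merged["num_input"].shape, merged["num_input"].dtype)

Both layouts carry the same values (missing ages stay NaN); the second one
simply saves the graph from concatenating the columns again.
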
.. GENERATED FROM PYTHON SOURCE LINES 140-142

Convert the pipeline into ONNX
++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 142-148

.. code-block:: default


    try:
        model_onnx = convert_sklearn(clf, "pipeline_titanic", inputs, target_opset=12)
    except Exception as e:
        print(e)


.. GENERATED FROM PYTHON SOURCE LINES 149-152

*scikit-learn* does implicit conversions when it can, *sklearn-onnx* does not.
The ONNX version of *OneHotEncoder* must be applied on columns of the same type,
so *pclass* is converted into strings and the columns the pipeline does not
use are dropped from the inputs.

.. GENERATED FROM PYTHON SOURCE LINES 152-166

.. code-block:: default


    X_train["pclass"] = X_train["pclass"].astype(str)
    X_test["pclass"] = X_test["pclass"].astype(str)
    white_list = numeric_features + categorical_features
    to_drop = [c for c in X_train.columns if c not in white_list]
    inputs = convert_dataframe_schema(X_train, to_drop)

    model_onnx = convert_sklearn(clf, "pipeline_titanic", inputs, target_opset=12)


    # And save.
    with open("pipeline_titanic.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())


.. GENERATED FROM PYTHON SOURCE LINES 167-173

Compare the predictions
+++++++++++++++++++++++

Final step: we need to ensure the converted model produces the same
predictions, both labels and probabilities.
Let's start with *scikit-learn*.

.. GENERATED FROM PYTHON SOURCE LINES 173-177

.. code-block:: default


    print("predict", clf.predict(X_test[:5]))
    print("predict_proba", clf.predict_proba(X_test[:1]))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    predict [1 1 0 0 1]
    predict_proba [[0.18274774 0.81725226]]


.. GENERATED FROM PYTHON SOURCE LINES 178-187

Now the predictions with *onnxruntime*.
We need to remove the dropped columns and to change the double vectors
into float vectors, because the graph declares these inputs as
*FloatTensorType* (float32).
*onnxruntime* does not accept a *dataframe*: inputs must be given as a
dictionary mapping each input name to an array.
Last detail: every column was described not as a vector but as a matrix
with one column, which explains the last line with the *reshape*.
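
The type question also matters on the *scikit-learn* side: *clf* was fitted
while *pclass* still held integers, and the cast to strings happens
afterwards. The sketch below, on made-up values, shows what a fitted
*OneHotEncoder* with ``handle_unknown="ignore"`` does when it later receives
the same classes as strings, and what happens when it is fitted on strings
from the start.

.. code-block:: python

    import warnings

    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import OneHotEncoder

    train = pd.DataFrame({"pclass": [1, 2, 3, 1]})    # integers, as at fit time
    test = pd.DataFrame({"pclass": ["1", "2", "3"]})  # strings, after astype(str)

    # Fitted on integers: "1" != 1, so every string is an unknown category.
    enc_int = OneHotEncoder(handle_unknown="ignore").fit(train)
    with warnings.catch_warnings():
        # Some versions may warn about unknown categories; silenced here.
        warnings.simplefilter("ignore")
        encoded_int = enc_int.transform(test).toarray()

    # Fitted after the cast: categories and inputs have the same type.
    enc_str = OneHotEncoder(handle_unknown="ignore").fit(train.astype(str))
    encoded_str = enc_str.transform(test).toarray()

    print(encoded_int)
    print(encoded_str)

The first encoder returns only zeros, the second one an identity matrix.
Casting a column before ``fit`` rather than after it keeps *scikit-learn* and
the ONNX graph on the same categories.
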
.. GENERATED FROM PYTHON SOURCE LINES 187-195

.. code-block:: default


    X_test2 = X_test.drop(to_drop, axis=1)
    inputs = {c: X_test2[c].values for c in X_test2.columns}
    for c in numeric_features:
        inputs[c] = inputs[c].astype(np.float32)
    for k in inputs:
        inputs[k] = inputs[k].reshape((inputs[k].shape[0], 1))


.. GENERATED FROM PYTHON SOURCE LINES 196-197

We are ready to run *onnxruntime*.

.. GENERATED FROM PYTHON SOURCE LINES 197-204

.. code-block:: default


    sess = rt.InferenceSession("pipeline_titanic.onnx", providers=["CPUExecutionProvider"])
    pred_onx = sess.run(None, inputs)
    print("predict", pred_onx[0][:5])
    print("predict_proba", pred_onx[1][:1])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    predict [1 1 0 0 1]
    predict_proba [{0: 0.3934966027736664, 1: 0.6065033674240112}]


The labels agree but the probabilities do not. The gap is explained by
*pclass*: *clf* was fitted while *pclass* still contained integers, so once
the column holds strings, the *scikit-learn* *OneHotEncoder*
(``handle_unknown="ignore"``) treats every value as an unknown category and
encodes it as zeros, whereas the ONNX graph matches it against the strings
``"1"``, ``"2"`` and ``"3"``. Converting *pclass* to strings before fitting
the pipeline should remove this difference.

.. GENERATED FROM PYTHON SOURCE LINES 205-212

Compute intermediate outputs
++++++++++++++++++++++++++++

Unfortunately, there is no way to ask *onnxruntime* to retrieve
the output of intermediate nodes. We need to modify the *ONNX* graph
before it is given to *onnxruntime*.
Let's first look at the list of intermediate outputs.

.. GENERATED FROM PYTHON SOURCE LINES 212-217

.. code-block:: default


    model_onnx = load_onnx_model("pipeline_titanic.onnx")
    for out in enumerate_model_node_outputs(model_onnx):
        print(out)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    merged_columns
    embarkedout
    sexout
    pclassout
    concat_result
    variable
    variable2
    variable1
    transformed_column
    label
    probability_tensor
    output_label
    probabilities
    output_probability


.. GENERATED FROM PYTHON SOURCE LINES 218-224

It is not easy to tell which output is which, as the *ONNX* graph has
more operators than the original *scikit-learn* pipeline.
The graph at :ref:`l-plot-complex-pipeline-graph`
helps us find the outputs of both the numerical and the textual pipelines:
*variable1* and *variable2*.
Let's look into the numerical pipeline first.

.. GENERATED FROM PYTHON SOURCE LINES 224-228
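
The principle behind extracting *variable1* can be illustrated on a toy graph:
a result computed inside the graph becomes retrievable once it is declared as
a graph output. This sketch is not how *skl2onnx* implements
``select_model_inputs_outputs`` (for instance it does not prune unused nodes).
It builds a two-node graph with ``onnx.helper`` and runs it with
``onnx.reference.ReferenceEvaluator``, available in recent *onnx* releases,
instead of *onnxruntime*. ``expose_output`` is a hypothetical helper written
for this illustration.

.. code-block:: python

    import numpy as np
    import onnx
    from onnx import TensorProto, helper
    from onnx.reference import ReferenceEvaluator

    # Toy graph: y = (x + 1) * 2, where "shifted" = x + 1 is an intermediate result.
    graph = helper.make_graph(
        nodes=[
            helper.make_node("Add", ["x", "one"], ["shifted"]),
            helper.make_node("Mul", ["shifted", "two"], ["y"]),
        ],
        name="toy",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [None, 2])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, 2])],
        initializer=[
            helper.make_tensor("one", TensorProto.FLOAT, [], [1.0]),
            helper.make_tensor("two", TensorProto.FLOAT, [], [2.0]),
        ],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])


    def expose_output(model, name):
        """Hypothetical helper: copy `model` and make `name` its only graph output."""
        new_model = onnx.ModelProto()
        new_model.CopyFrom(model)
        del new_model.graph.output[:]
        new_model.graph.output.append(
            helper.make_tensor_value_info(name, TensorProto.FLOAT, None)
        )
        return new_model


    x = np.array([[1.0, 2.0]], dtype=np.float32)
    final = ReferenceEvaluator(model).run(None, {"x": x})[0]
    intermediate = ReferenceEvaluator(expose_output(model, "shifted")).run(None, {"x": x})[0]
    print("y:", final)
    print("shifted:", intermediate)

The original model is left untouched, which mirrors the example: *num_onnx*
is a new model, and *model_onnx* keeps its final outputs.
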
.. code-block:: default


    num_onnx = select_model_inputs_outputs(model_onnx, "variable1")
    save_onnx_model(num_onnx, "pipeline_titanic_numerical.onnx")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    b'\x08\x07\x12\x08skl2onnx\x1a\x061.16.0"\x07ai.onnx(\x002\x00:\xcd\x03\n:\n\x03age\n\x04fare\x12\x0emerged_columns\x1a\x06Concat"\x06Concat*\x0b\n\x04axis\x18\x01\xa0\x01\x02:\x00\n}\n\x0emerged_columns\x12\x08variable\x1a\x07Imputer"\x07Imputer*#\n\x14imputed_value_floats=\x00\x00\xe0A=\x00\x00`A\xa0\x01\x06*\x1e\n\x14replaced_value_float\x15\x00\x00\xc0\x7f\xa0\x01\x01:\nai.onnx.ml\n^\n\x08variable\x12\tvariable1\x1a\x06Scaler"\x06Scaler*\x15\n\x06offset=q_\xebA=\xc3\x08\x05B\xa0\x01\x06*\x14\n\x05scale=%(\x9f==\x94v\x9d<\xa0\x01\x06:\nai.onnx.ml\x12\x10pipeline_titanic*\x1f\x08\x02\x10\x07:\x0b\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\tB\x0cshape_tensorZ\x16\n\x06pclass\x12\x0c\n\n\x08\x08\x12\x06\n\x00\n\x02\x08\x01Z\x13\n\x03sex\x12\x0c\n\n\x08\x08\x12\x06\n\x00\n\x02\x08\x01Z\x13\n\x03age\x12\x0c\n\n\x08\x01\x12\x06\n\x00\n\x02\x08\x01Z\x14\n\x04fare\x12\x0c\n\n\x08\x01\x12\x06\n\x00\n\x02\x08\x01Z\x18\n\x08embarked\x12\x0c\n\n\x08\x08\x12\x06\n\x00\n\x02\x08\x01b\x0b\n\tvariable1B\x04\n\x00\x10\x0bB\x0e\n\nai.onnx.ml\x10\x01'


.. GENERATED FROM PYTHON SOURCE LINES 229-230

Let's compute the numerical features.

.. GENERATED FROM PYTHON SOURCE LINES 230-237

.. code-block:: default


    sess = rt.InferenceSession(
        "pipeline_titanic_numerical.onnx", providers=["CPUExecutionProvider"]
    )
    numX = sess.run(None, inputs)
    print("numerical features", numX[0][:1])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    numerical features [[-0.88761026 -0.5095364 ]]


.. GENERATED FROM PYTHON SOURCE LINES 238-239

We do the same for the textual features.

.. GENERATED FROM PYTHON SOURCE LINES 239-249
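
Before moving on, the numerical features can be related to the attributes of
the two nodes that produce them: the ONNX-ML *Imputer* replaces NaN with
``imputed_value_floats`` (the medians learned by *SimpleImputer*), and the
ONNX-ML *Scaler* computes ``(x - offset) * scale``. For *StandardScaler*, one
would therefore expect ``offset = mean_`` and ``scale = 1 / scale_``. The
sketch below checks that this assumed mapping reproduces *scikit-learn* on
made-up data; it is a hand-written check, not code taken from *skl2onnx*.

.. code-block:: python

    import numpy as np
    from sklearn.impute import SimpleImputer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    # Made-up (age, fare) rows with missing values.
    X = np.array(
        [[22.0, 7.25], [np.nan, 71.28], [35.0, np.nan], [54.0, 51.86], [2.0, 21.07]]
    )
    num_pipe = Pipeline(
        steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]
    ).fit(X)

    # Attribute values an Imputer + Scaler pair would need (assumed mapping).
    imputed_values = num_pipe.named_steps["imputer"].statistics_  # medians
    offset = num_pipe.named_steps["scaler"].mean_
    scale = 1.0 / num_pipe.named_steps["scaler"].scale_

    # Imputer: replace NaN by the imputed value; Scaler: (x - offset) * scale.
    filled = np.where(np.isnan(X), imputed_values, X)
    by_hand = (filled - offset) * scale

    print("medians:", imputed_values)
    print("same as scikit-learn:", np.allclose(by_hand, num_pipe.transform(X)))
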
.. code-block:: default


    print(model_onnx)
    text_onnx = select_model_inputs_outputs(model_onnx, "variable2")
    save_onnx_model(text_onnx, "pipeline_titanic_textual.onnx")
    sess = rt.InferenceSession(
        "pipeline_titanic_textual.onnx", providers=["CPUExecutionProvider"]
    )
    numT = sess.run(None, inputs)
    print("textual features", numT[0][:1])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ir_version: 7
    opset_import {
      domain: ""
      version: 11
    }
    opset_import {
      domain: "ai.onnx.ml"
      version: 1
    }
    producer_name: "skl2onnx"
    producer_version: "1.16.0"
    domain: "ai.onnx"
    model_version: 0
    doc_string: ""
    graph {
      node {
        input: "age"
        input: "fare"
        output: "merged_columns"
        name: "Concat"
        op_type: "Concat"
        domain: ""
        attribute {
          name: "axis"
          type: INT
          i: 1
        }
      }
      node {
        input: "embarked"
        output: "embarkedout"
        name: "OneHotEncoder"
        op_type: "OneHotEncoder"
        domain: "ai.onnx.ml"
        attribute {
          name: "cats_strings"
          type: STRINGS
          strings: "C"
          strings: "Q"
          strings: "S"
          strings: "missing"
        }
        attribute {
          name: "zeros"
          type: INT
          i: 1
        }
      }
      node {
        input: "sex"
        output: "sexout"
        name: "OneHotEncoder1"
        op_type: "OneHotEncoder"
        domain: "ai.onnx.ml"
        attribute {
          name: "cats_strings"
          type: STRINGS
          strings: "female"
          strings: "male"
        }
        attribute {
          name: "zeros"
          type: INT
          i: 1
        }
      }
      node {
        input: "pclass"
        output: "pclassout"
        name: "OneHotEncoder2"
        op_type: "OneHotEncoder"
        domain: "ai.onnx.ml"
        attribute {
          name: "cats_strings"
          type: STRINGS
          strings: "1"
          strings: "2"
          strings: "3"
        }
        attribute {
          name: "zeros"
          type: INT
          i: 1
        }
      }
      node {
        input: "embarkedout"
        input: "sexout"
        input: "pclassout"
        output: "concat_result"
        name: "Concat1"
        op_type: "Concat"
        domain: ""
        attribute {
          name: "axis"
          type: INT
          i: 2
        }
      }
      node {
        input: "merged_columns"
        output: "variable"
        name: "Imputer"
        op_type: "Imputer"
        domain: "ai.onnx.ml"
        attribute {
          name: "imputed_value_floats"
          type: FLOATS
          floats: 28
          floats: 14
        }
        attribute {
          name: "replaced_value_float"
          type: FLOAT
          f: nan
        }
      }
      node {
        input: "concat_result"
        input: "shape_tensor"
        output: "variable2"
        name: "Reshape"
        op_type: "Reshape"
        domain: ""
      }
      node {
        input: "variable"
        output: "variable1"
        name: "Scaler"
        op_type: "Scaler"
        domain: "ai.onnx.ml"
        attribute {
          name: "offset"
          type: FLOATS
          floats: 29.4216022
          floats: 33.2585564
        }
        attribute {
          name: "scale"
          type: FLOATS
          floats: 0.0777132884
          floats: 0.0192215815
        }
      }
      node {
        input: "variable1"
        input: "variable2"
        output: "transformed_column"
        name: "Concat2"
        op_type: "Concat"
        domain: ""
        attribute {
          name: "axis"
          type: INT
          i: 1
        }
      }
      node {
        input: "transformed_column"
        output: "label"
        output: "probability_tensor"
        name: "LinearClassifier"
        op_type: "LinearClassifier"
        domain: "ai.onnx.ml"
        attribute {
          name: "classlabels_ints"
          type: INTS
          ints: 0
          ints: 1
        }
        attribute {
          name: "coefficients"
          type: FLOATS
          floats: 0.424732059
          floats: 0.0460702516
          floats: -0.384316146
          floats: 0.348065704
          floats: 0.287240565
          floats: -0.251526088
          floats: -1.22550428
          floats: 1.22496843
          floats: -1.0177319
          floats: -0.0480071418
          floats: 1.06520307
          floats: -0.424732059
          floats: -0.0460702516
          floats: 0.384316146
          floats: -0.348065704
          floats: -0.287240565
          floats: 0.251526088
          floats: 1.22550428
          floats: -1.22496843
          floats: 1.0177319
          floats: 0.0480071418
          floats: -1.06520307
        }
        attribute {
          name: "intercepts"
          type: FLOATS
          floats: -0.219931483
          floats: 0.219931483
        }
        attribute {
          name: "multi_class"
          type: INT
          i: 1
        }
        attribute {
          name: "post_transform"
          type: STRING
          s: "LOGISTIC"
        }
      }
      node {
        input: "label"
        output: "output_label"
        name: "Cast"
        op_type: "Cast"
        domain: ""
        attribute {
          name: "to"
          type: INT
          i: 7
        }
      }
      node {
        input: "probability_tensor"
        output: "probabilities"
        name: "Normalizer"
        op_type: "Normalizer"
        domain: "ai.onnx.ml"
        attribute {
          name: "norm"
          type: STRING
          s: "L1"
        }
      }
      node {
        input: "probabilities"
        output: "output_probability"
        name: "ZipMap"
        op_type: "ZipMap"
        domain: "ai.onnx.ml"
        attribute {
          name: "classlabels_int64s"
          type: INTS
          ints: 0
          ints: 1
        }
      }
      name: "pipeline_titanic"
      initializer {
        dims: 2
        data_type: 7
        int64_data: -1
        int64_data: 9
        name: "shape_tensor"
      }
      input {
        name: "pclass"
        type {
          tensor_type {
            elem_type: 8
            shape {
              dim {
              }
              dim {
                dim_value: 1
              }
            }
          }
        }
      }
      input {
        name: "sex"
        type {
          tensor_type {
            elem_type: 8
            shape {
              dim {
              }
              dim {
                dim_value: 1
              }
            }
          }
        }
      }
      input {
        name: "age"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
              }
              dim {
                dim_value: 1
              }
            }
          }
        }
      }
      input {
        name: "fare"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
              }
              dim {
                dim_value: 1
              }
            }
          }
        }
      }
      input {
        name: "embarked"
        type {
          tensor_type {
            elem_type: 8
            shape {
              dim {
              }
              dim {
                dim_value: 1
              }
            }
          }
        }
      }
      output {
        name: "output_label"
        type {
          tensor_type {
            elem_type: 7
            shape {
              dim {
              }
            }
          }
        }
      }
      output {
        name: "output_probability"
        type {
          sequence_type {
            elem_type {
              map_type {
                key_type: 7
                value_type {
                  tensor_type {
                    elem_type: 1
                  }
                }
              }
            }
          }
        }
      }
    }

    textual features [[0. 1. 0. 0. 1. 0. 0. 0. 1.]]


.. GENERATED FROM PYTHON SOURCE LINES 250-254

Display the sub-ONNX graph
++++++++++++++++++++++++++

Finally, let's draw both sub-graphs, starting with the numerical pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 254-272

.. code-block:: default


    pydot_graph = GetPydotGraph(
        num_onnx.graph,
        name=num_onnx.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"
        ),
    )
    pydot_graph.write_dot("pipeline_titanic_num.dot")

    os.system("dot -O -Gdpi=300 -Tpng pipeline_titanic_num.dot")

    image = plt.imread("pipeline_titanic_num.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis("off")


.. image-sg:: /auto_examples/images/sphx_glr_plot_intermediate_outputs_001.png
   :alt: plot intermediate outputs
   :srcset: /auto_examples/images/sphx_glr_plot_intermediate_outputs_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (-0.5, 1229.5, 2558.5, -0.5)


.. GENERATED FROM PYTHON SOURCE LINES 273-274

Then the textual pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 274-292
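
The nine textual columns printed earlier, ``[[0. 1. 0. 0. 1. 0. 0. 0. 1.]]``,
can be read back with the ``cats_strings`` attributes of the three
*OneHotEncoder* nodes, which *Concat1* joins in the order *embarked*, *sex*,
*pclass* before the *Reshape* to 9 columns. A small sketch of that decoding,
with the categories copied from the model printed above; ``decode_one_hot`` is
a hypothetical helper, not part of *skl2onnx*.

.. code-block:: python

    import numpy as np

    # Categories copied from the cats_strings attributes printed above,
    # in the order used by the Concat1 node.
    categories = {
        "embarked": ["C", "Q", "S", "missing"],
        "sex": ["female", "male"],
        "pclass": ["1", "2", "3"],
    }


    def decode_one_hot(row, categories):
        """Hypothetical helper: split a one-hot row into blocks and read each block."""
        decoded, start = {}, 0
        for column, cats in categories.items():
            block = np.asarray(row[start : start + len(cats)])
            # With zeros=1, an unknown category gives an all-zero block.
            decoded[column] = cats[int(np.argmax(block))] if block.any() else None
            start += len(cats)
        return decoded


    row = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], dtype=np.float32)
    print("width:", sum(len(c) for c in categories.values()))
    print(decode_one_hot(row, categories))

For the first test passenger this reads *embarked* = ``Q``, *sex* =
``female`` and *pclass* = ``3``.
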
.. code-block:: default


    pydot_graph = GetPydotGraph(
        text_onnx.graph,
        name=text_onnx.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"
        ),
    )
    pydot_graph.write_dot("pipeline_titanic_text.dot")

    os.system("dot -O -Gdpi=300 -Tpng pipeline_titanic_text.dot")

    image = plt.imread("pipeline_titanic_text.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis("off")


.. image-sg:: /auto_examples/images/sphx_glr_plot_intermediate_outputs_002.png
   :alt: plot intermediate outputs
   :srcset: /auto_examples/images/sphx_glr_plot_intermediate_outputs_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (-0.5, 5630.5, 2735.5, -0.5)


.. GENERATED FROM PYTHON SOURCE LINES 293-294

**Versions used for this example**

.. GENERATED FROM PYTHON SOURCE LINES 294-300

.. code-block:: default


    print("numpy:", numpy.__version__)
    print("scikit-learn:", sklearn.__version__)
    print("onnx: ", onnx.__version__)
    print("onnxruntime: ", rt.__version__)
    print("skl2onnx: ", skl2onnx.__version__)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    numpy: 1.23.5
    scikit-learn: 1.4.dev0
    onnx:  1.15.0
    onnxruntime:  1.16.0+cu118
    skl2onnx:  1.16.0


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.421 seconds)


.. _sphx_glr_download_auto_examples_plot_intermediate_outputs.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_intermediate_outputs.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_intermediate_outputs.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_