.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorial/plot_fbegin_investigate.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_tutorial_plot_fbegin_investigate.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorial_plot_fbegin_investigate.py:


Intermediate results and investigation
======================================

.. index:: investigate, intermediate results

There are many reasons why a user may want more than
just running a model converted into ONNX.
Intermediate results, such as the output of
every node in the graph, may be needed.
The ONNX graph may need to be altered to remove some nodes.
Transfer learning, for instance, usually removes the last layers of
a deep neural network. Another reason is debugging.
It often happens that the runtime fails to compute the predictions
due to a shape mismatch. It is then useful to get the shape
of every intermediate result. This example looks into two
ways of doing it.

Look into pipeline steps
++++++++++++++++++++++++

The first way is a tricky one: it overloads
methods *transform*, *predict* and *predict_proba*
to keep a copy of inputs and outputs. It then goes
through every step of the pipeline. If the pipeline
has *n* steps, it converts the pipeline with step 1,
then the pipeline with steps 1, 2, then 1, 2, 3...

.. GENERATED FROM PYTHON SOURCE LINES 30-41

.. code-block:: default

    import numpy
    from onnx.reference import ReferenceEvaluator
    from onnxruntime import InferenceSession
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.cluster import KMeans
    from sklearn.datasets import load_iris
    from skl2onnx import to_onnx
    from skl2onnx.helpers import collect_intermediate_steps
    from skl2onnx.common.data_types import FloatTensorType

.. GENERATED FROM PYTHON SOURCE LINES 42-43

The pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 43-50

.. code-block:: default

    data = load_iris()
    X = data.data

    pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
    pipe.fit(X)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Pipeline(steps=[('std', StandardScaler()),
                    ('km', KMeans(n_clusters=3, n_init=3))])

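The idea described earlier (wrap *transform* to keep a copy of inputs
and outputs, then consider the pipeline truncated after step 1, then
after steps 1, 2, and so on) can be sketched in plain Python.
This is only an illustration under stated assumptions: ``Scale``,
``Shift``, ``record_calls``, ``run`` and ``prefixes`` are hypothetical
names introduced here, and the code does not reproduce how
``collect_intermediate_steps`` is actually implemented.

```python
# Hypothetical toy steps: each one only needs a ``transform`` method.
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def transform(self, values):
        return [v * self.factor for v in values]


class Shift:
    def __init__(self, offset):
        self.offset = offset

    def transform(self, values):
        return [v + self.offset for v in values]


def record_calls(step):
    """Replace ``step.transform`` with a wrapper keeping a copy of the
    last input and output in ``step.cache`` (hypothetical attribute)."""
    original = step.transform
    step.cache = {}

    def wrapper(values):
        result = original(values)
        step.cache["input"] = list(values)
        step.cache["output"] = list(result)
        return result

    step.transform = wrapper
    return step


def run(steps, values):
    """Apply every step in order, as a pipeline would."""
    for _, step in steps:
        values = step.transform(values)
    return values


def prefixes(steps):
    """Yield the truncated pipelines: step 1, then steps 1-2, ..."""
    for end in range(1, len(steps) + 1):
        yield steps[:end]


toy_steps = [
    ("scale", record_calls(Scale(2.0))),
    ("shift", record_calls(Shift(1.0))),
]
data_in = [1.0, 2.0, 3.0]

# Running the full pipeline once populates the cache of every step.
final = run(toy_steps, data_in)

# Each truncated pipeline should reproduce what its last step cached.
for prefix in prefixes(toy_steps):
    name, last = prefix[-1]
    expected = list(last.cache["output"])
    truncated = run(prefix, data_in)
    print(name, expected, truncated == expected)
```

In the real example, the truncated pipelines are ONNX graphs and the
cached values come from the fitted scikit-learn steps, so the comparison
is made between two independent implementations and not, as in this toy
version, between two runs of the same code.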


.. GENERATED FROM PYTHON SOURCE LINES 51-54

The function goes through every step,
overloads the methods *transform* and
returns an ONNX graph for every step.

.. GENERATED FROM PYTHON SOURCE LINES 54-58

.. code-block:: default

    steps = collect_intermediate_steps(
        pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
    )

.. GENERATED FROM PYTHON SOURCE LINES 59-61

We call method *transform* to populate the
cache that the overloaded methods *transform* keep.

.. GENERATED FROM PYTHON SOURCE LINES 61-63

.. code-block:: default

    pipe.transform(X)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    array([[3.12119834, 0.21295824, 3.98940603],
           [2.6755083 , 0.99604549, 4.01793312],
           [2.97416665, 0.65198444, 4.19343668],
           [2.88014429, 0.9034561 , 4.19784749],
           [3.30022609, 0.40215457, 4.11157152],
           [3.50554424, 1.21154793, 3.89893116],
           [3.14856384, 0.50244932, 4.21638048],
           [2.99184826, 0.09132468, 3.97313411],
           [2.92515933, 1.42174651, 4.40757189],
           [2.79398956, 0.78993078, 4.05764261],
           [3.32125333, 0.78999385, 3.92088109],
           [3.0493632 , 0.27618123, 4.07853631],
           [2.80635045, 1.03497888, 4.16440431],
           [3.21220972, 1.33482453, 4.63069748],
           [3.88834965, 1.63865558, 4.14619343],
           [4.4998303 , 2.39898792, 4.49547518],
           [3.60978017, 1.20748818, 4.02966144],
           [3.05594182, 0.21618828, 3.91388548],
           [3.34493953, 1.20986655, 3.72562039],
           [3.50065397, 0.86706182, 4.10101938],
           [2.80825681, 0.50401564, 3.66383713],
           [3.27800809, 0.66826437, 3.94496718],
           [3.58990876, 0.68658071, 4.51061335],
           [2.55934697, 0.47945627, 3.57996434],
           [2.96493153, 0.36345425, 3.98817445],
           [2.55682739, 0.99023912, 3.88431906],
           [2.8279719 , 0.22683089, 3.79088782],
           [3.05970831, 0.2947186 , 3.89539875],
           [2.95425291, 0.25361098, 3.88085622],
           [2.87745051, 0.65019824, 4.09851673],
           [2.73238773, 0.80138328, 4.01796142],
           [2.73361981, 0.52309257, 3.57350896],
           [4.11853014, 1.57658655, 4.5037664 ],
           [4.22845606, 1.87652483, 4.4465301 ],
           [2.71452112, 0.76858489, 3.97906378],
           [2.86508665, 0.54896332, 4.01986385],
           [3.0573692 , 0.63079314, 3.80064093],
           [3.40284985, 0.45982568, 4.25136846],
           [3.00742655, 1.2336976 , 4.42052558],
           [2.95472117, 0.14580827, 3.90865188],
           [3.12324651, 0.20261743, 4.01192633],
           [2.90164193, 2.67055552, 4.64398605],
           [3.15411688, 0.90927099, 4.42154566],
           [2.8613548 , 0.50081008, 3.70483773],
           [3.34606471, 0.92159916, 3.9078554 ],
           [2.65231058, 1.01946042, 4.01421067],
           [3.53206587, 0.86953764, 4.14238152],
           [2.99813103, 0.72275914, 4.23577398],
           [3.34116935, 0.72324305, 3.97409784],
           [2.90222887, 0.30295342, 3.97223984],
           [1.9003878 , 3.43619989, 0.95288059],
           [1.41851492, 2.97232682, 0.99352148],
           [1.68457079, 3.51850037, 0.72661726],
           [0.96940962, 3.33264308, 2.69898424],
           [0.9112523 , 3.35747592, 1.11074501],
           [0.35721918, 2.77550662, 1.8143491 ],
           [1.59351202, 3.01808184, 1.00650285],
           [1.50213315, 2.77360088, 3.31296552],
           [1.11632078, 3.21148368, 1.14114175],
           [0.77921299, 2.66294828, 2.42994048],
           [1.97194958, 3.62389817, 3.73666782],
           [0.77530513, 2.70011145, 1.45918639],
           [1.25941769, 3.53658932, 2.74268279],
           [0.66155141, 2.98813829, 1.28976474],
           [0.73833453, 2.32311723, 2.05251547],
           [1.46572707, 3.14311522, 0.98780965],
           [0.80185102, 2.68234835, 1.67700171],
           [0.568386 , 2.63954211, 2.12682734],
           [1.19987895, 3.97369206, 2.33743839],
           [0.67881532, 2.87494798, 2.46667974],
           [1.34222961, 3.03853641, 1.1880022 ],
           [0.53061062, 2.8022861 , 1.63233668],
           [0.79234309, 3.68305664, 1.65142259],
           [0.57371215, 2.96833851, 1.54593744],
           [0.90589785, 2.9760862 , 1.2933375 ],
           [1.22490527, 3.13002382, 1.03085926],
           [1.26783271, 3.56679427, 1.09304603],
           [1.42114042, 3.5903606 , 0.52050254],
           [0.58974672, 2.93839428, 1.34712856],
           [0.76432091, 2.58203512, 2.44164622],
           [0.89738242, 2.99796537, 2.69027665],
           [0.98549851, 2.92597852, 2.76965187],
           [0.3921368 , 2.68907313, 2.02829879],
           [0.54223583, 3.42215998, 1.4211892 ],
           [0.90567816, 2.62771445, 1.88799766],
           [1.70872911, 2.75915071, 1.39853465],
           [1.48190142, 3.30075052, 0.78009974],
           [1.06129323, 3.73017167, 2.2083069 ],
           [0.81863359, 2.37943811, 1.87666989],
           [0.599882 , 2.98789866, 2.41035271],
           [0.4914813 , 2.89079656, 2.26782134],
           [0.84409423, 2.86642713, 1.25085451],
           [0.38941349, 2.86642575, 2.11791607],
           [1.53271026, 2.96966239, 3.35089399],
           [0.30831638, 2.77003779, 2.05312152],
           [0.81726253, 2.38255534, 1.83091351],
           [0.56428027, 2.55559903, 1.80454586],
           [0.72672271, 2.8455521 , 1.39825227],
           [1.28805849, 2.56987887, 3.06324547],
           [0.38163798, 2.64007308, 1.89861511],
           [2.31271244, 4.24274589, 1.0584579 ],
           [0.76585766, 3.57067982, 1.5185265 ],
           [2.14762671, 4.44150237, 0.52472 ],
           [1.17645413, 3.69480186, 0.77236486],
           [1.73594932, 4.11613683, 0.53031563],
           [2.78128346, 5.03326801, 1.2022172 ],
           [1.22550604, 3.3503222 , 2.74462238],
           [2.2426558 , 4.577021 , 0.92275933],
           [1.50462864, 4.363498 , 1.40314162],
           [3.22975724, 4.79334275, 1.48323372],
           [1.71837714, 3.62749566, 0.4787491 ],
           [1.10409694, 3.89360823, 1.0325986 ],
           [1.80475907, 4.1132966 , 0.27818948],
           [0.94858807, 3.82688169, 1.91870424],
           [1.39433359, 3.91538879, 1.49910975],
           [1.90677079, 3.89835633, 0.68622715],
           [1.39713702, 3.70128288, 0.46463058],
           [3.85224062, 5.18341242, 2.10127163],
           [2.95786451, 5.58136629, 1.83092395],
           [1.17790381, 4.02615768, 2.37017622],
           [2.27442972, 4.31907679, 0.52540209],
           [0.91211061, 3.4288432 , 1.62249456],
           [2.77937737, 5.19031307, 1.47042293],
           [0.84735471, 3.64273089, 1.15814207],
           [2.15695444, 4.00723617, 0.520093 ],
           [2.33581345, 4.2637671 , 0.66660166],
           [0.79774043, 3.45930032, 1.08324891],
           [1.022307 , 3.27575645, 0.94925151],
           [1.3842265 , 4.05342943, 0.84098317],
           [2.03854964, 4.1585729 , 0.75748198],
           [2.28297732, 4.71100584, 1.07124861],
           [3.88774921, 5.12224641, 2.17345728],
           [1.47357101, 4.13401784, 0.87682321],
           [0.7964005 , 3.39830644, 1.11534598],
           [0.80521086, 3.63719075, 1.59782917],
           [2.8607372 , 5.08776655, 1.25982873],
           [2.3101089 , 4.00416552, 1.07214028],
           [1.46990247, 3.58815834, 0.51434392],
           [0.97017134, 3.19454679, 1.0762733 ],
           [1.97333575, 4.09907253, 0.23050145],
           [2.07939567, 4.28416057, 0.57373487],
           [2.06609741, 4.17402084, 0.51130902],
           [0.76585766, 3.57067982, 1.5185265 ],
           [2.24723796, 4.32128686, 0.54141867],
           [2.42521977, 4.3480018 , 0.85128501],
           [1.82594618, 4.1240495 , 0.52475835],
           [1.03093862, 3.97564407, 1.52100812],
           [1.44892686, 3.7539635 , 0.44371189],
           [2.17585453, 3.7969924 , 1.08437101],
           [1.00508668, 3.25638099, 1.13739231]])

.. GENERATED FROM PYTHON SOURCE LINES 64-66

We compute every step and compare
ONNX and scikit-learn outputs.

.. GENERATED FROM PYTHON SOURCE LINES 66-88

.. code-block:: default

    for step in steps:
        print("----------------------------")
        print(step["model"])
        onnx_step = step["onnx_step"]
        sess = InferenceSession(
            onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
        onnx_output = onnx_outputs[-1]
        skl_outputs = step["model"]._debug.outputs["transform"]

        # comparison
        diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
        print("difference", diff)

    # That was the first way: dynamically overwrite
    # every method transform or predict in a scikit-learn
    # pipeline to capture the input and output of every step,
    # compare them to the output produced by truncated ONNX
    # graphs built from the first one.
    #

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ----------------------------
    StandardScaler()
    difference 4.799262827148709e-07
    ----------------------------
    KMeans(n_clusters=3, n_init=3)
    difference 1.095537650763756e-06

.. GENERATED FROM PYTHON SOURCE LINES 89-96

Python runtime to look into every node
++++++++++++++++++++++++++++++++++++++

The Python runtime makes it easy to
look into every node of the ONNX graph.
This option can be used to find where the computation
fails due to NaN values or a dimension mismatch.

.. GENERATED FROM PYTHON SOURCE LINES 96-103

.. code-block:: default

    onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)

    oinf = ReferenceEvaluator(onx, verbose=1)
    oinf.run(None, {"X": X[:2].astype(numpy.float32)})

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
           [2.675508 , 0.99604493, 4.017933 ]], dtype=float32)]

.. GENERATED FROM PYTHON SOURCE LINES 104-105

And to get a sense of the intermediate results.

.. GENERATED FROM PYTHON SOURCE LINES 105-111

.. code-block:: default

    oinf = ReferenceEvaluator(onx, verbose=3)
    oinf.run(None, {"X": X[:2].astype(numpy.float32)})

    # This way is usually better if you need to investigate
    # issues within the code of the runtime for an operator.

.. rst-class:: sphx-glr-script-out

.. code-block:: none

     +C Ad_Addcst: float32:(3,) in [0.9830552339553833, 5.035177230834961]
     +C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1359702348709106]
     +C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
     +I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
    Scaler(X) -> variable
     + variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
    ReduceSumSquare(variable) -> Re_reduced0
     + Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
    Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
     + Mu_C0: float32:(2, 1) in [0.0, 0.0]
    Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
     + Ge_Y0: float32:(2, 3) in [-10.366023063659668, 7.967348575592041]
    Add(Re_reduced0, Ge_Y0) -> Ad_C01
     + Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.817853927612305]
    Add(Ad_Addcst, Ad_C01) -> Ad_C0
     + Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.143783569335938]
    ArgMin(Ad_C0) -> label
     + label: int64:(2,) in [1, 1]
    Sqrt(Ad_C0) -> scores
     + scores: float32:(2, 3) in [0.2129589319229126, 4.017932891845703]

    [array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
           [2.675508 , 0.99604493, 4.017933 ]], dtype=float32)]

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.291 seconds)


.. _sphx_glr_download_auto_tutorial_plot_fbegin_investigate.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_fbegin_investigate.py <plot_fbegin_investigate.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_fbegin_investigate.ipynb <plot_fbegin_investigate.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_