Investigate a pipeline

The following example shows how to look into a converted model and easily find errors at every step of the pipeline.

Create a pipeline

We reuse the pipeline implemented in the example Pipelining: chaining a PCA and a logistic regression. Note that the ONNX-ML Imputer does not handle string types, so a string-imputation step cannot be part of the final ONNX pipeline and would have to be removed. The pipeline below contains no imputer, so no such change is needed here.

import skl2onnx
import onnx
import sklearn
import numpy
import pickle
from skl2onnx.helpers import collect_intermediate_steps
import onnxruntime as rt
from onnxconverter_common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
import numpy as np
import pandas as pd

from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Keep only the first 1000 handwritten-digit samples so the example stays fast.
digits = datasets.load_digits()
X_digits, y_digits = digits.data[:1000], digits.target[:1000]

# Two-step pipeline: a PCA projection feeding a logistic regression classifier.
pipe = Pipeline(
    steps=[
        ("pca", PCA()),
        ("logistic", LogisticRegression()),
    ]
)
pipe.fit(X_digits, y_digits)
/home/xadupre/github/scikit-learn/sklearn/linear_model/_logistic.py:472: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
Pipeline(steps=[('pca', PCA()), ('logistic', LogisticRegression())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Conversion to ONNX

# One float input named "input": a dynamic batch dimension (None) and one
# column per digit feature.
initial_types = [("input", FloatTensorType((None, X_digits.shape[1])))]
model_onnx = convert_sklearn(pipe, initial_types=initial_types, target_opset=12)

# Execute the converted graph with onnxruntime on the CPU.
sess = rt.InferenceSession(
    model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
batch = X_digits[:2]

print("skl predict_proba")
print(pipe.predict_proba(batch))

# Output index 1 holds the class probabilities. NOTE(review): presumably a
# sequence of per-class maps (ZipMap); the DataFrame turns it into a plain
# 2-D array for printing - confirm against the converter options.
onx_pred = sess.run(None, {"input": batch.astype(np.float32)})[1]
df = pd.DataFrame(onx_pred)
print("onnx predict_proba")
print(df.values)
skl predict_proba
[[9.99998536e-01 5.99063158e-19 3.48548953e-10 1.55765726e-08
  3.32559745e-10 1.21314653e-06 3.98959930e-08 1.22513839e-07
  2.23871272e-08 4.98148509e-08]
 [1.47648437e-14 9.99999301e-01 1.05811967e-10 7.49298733e-13
  2.48627417e-07 8.75686484e-12 5.39025135e-11 2.95899938e-11
  4.50528833e-07 1.30607478e-13]]
onnx predict_proba
[[9.99998569e-01 5.99062501e-19 3.48550355e-10 1.55766493e-08
  3.32561811e-10 1.21315134e-06 3.98961930e-08 1.22514706e-07
  2.23872494e-08 4.98151529e-08]
 [1.47648956e-14 9.99999285e-01 1.05811991e-10 7.49297488e-13
  2.48627885e-07 8.75685548e-12 5.39024415e-11 2.95899520e-11
  4.50529058e-07 1.30607344e-13]]

Intermediate steps

Let’s imagine the final output is wrong and we need to look into each component of the pipeline to find which one is failing. The following method modifies the scikit-learn pipeline to capture the intermediate outputs and produces a smaller ONNX graph for every operator.

# Instrument the pipeline so each step records what it computes, and get one
# ONNX graph per step to compare against.
steps = collect_intermediate_steps(pipe, "pipeline", initial_types)

# The pipeline has exactly two steps: PCA and LogisticRegression.
assert len(steps) == 2

# Running the instrumented pipeline fills each step's ``_debug`` attribute with
# the inputs/outputs it saw; the loop below reads ``_debug.outputs`` back.
pipe.predict_proba(X_digits[:2])

# The same two rows, cast once to float32 as declared in ``initial_types``.
onnx_feed = {"input": X_digits[:2].astype(np.float32)}

for i, step in enumerate(steps):
    onnx_step = step["onnx_step"]
    sess = rt.InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # NOTE(review): every step graph is fed the raw "input", so it presumably
    # covers the pipeline up to and including this step - confirm.
    onnx_outputs = sess.run(None, onnx_feed)
    skl_outputs = step["model"]._debug.outputs
    # Number the steps from 1 (previously every step was labelled "step 1").
    print("step", i + 1, type(step["model"]))
    print("skl outputs")
    print(skl_outputs)
    print("onnx outputs")
    print(onnx_outputs)
step 1 <class 'sklearn.decomposition._pca.PCA'>
skl outputs
{'transform': array([[-9.78697129e+00,  7.22639567e+00, -2.16935601e+01,
         1.13765854e+01, -3.54566122e+00, -5.59543345e+00,
         4.71459904e+00,  4.29410146e+00, -5.71520266e+00,
        -3.31533698e+00, -3.42040920e-01,  2.90474751e+00,
        -3.18177631e-01, -6.66363079e-01, -2.82714171e+00,
         5.91632481e+00, -9.69544780e-01,  1.92676767e+00,
         1.71450677e+00,  9.60454853e-01,  3.81570991e-01,
        -1.37130203e+00,  4.29353551e+00,  2.32392659e+00,
         7.13256034e-01,  3.00982060e+00, -1.98303620e+00,
        -4.81811365e-01, -1.90930400e-01,  2.03950266e+00,
         1.59803428e+00, -1.46831581e+00, -1.70903280e+00,
         7.93109126e-02, -1.62244448e-01,  5.10619572e-02,
        -6.63308841e-01,  1.35869345e+00, -1.03930533e+00,
         2.09485311e+00,  2.15669105e+00, -7.78040093e-02,
        -4.01347652e-02,  8.40159293e-01, -4.74891758e-01,
        -1.14564701e-01, -5.31817617e-02, -6.87010227e-01,
        -1.29090165e-01,  2.12032919e-01,  3.63901656e-01,
        -1.29285214e-01, -8.14384613e-02, -3.82919696e-02,
        -9.76885583e-03, -1.39046240e-02,  1.59100433e-03,
        -2.87444919e-03,  5.75119957e-03,  1.85595427e-03,
        -5.00911047e-03, -2.53068224e-15, -6.30369386e-16,
        -9.16970102e-16],
       [ 1.54267314e+01, -4.91291516e+00,  1.74676972e+01,
        -1.13960509e+01,  5.64555024e+00, -5.73696034e+00,
        -2.08026490e+00,  5.23721537e+00,  3.37859393e+00,
         3.60754149e+00,  2.90967608e+00, -3.75628331e+00,
        -1.21238177e+00, -5.21796290e+00, -4.95051435e+00,
        -4.01835168e+00, -2.97046115e+00, -5.64772188e+00,
         5.61898054e+00, -4.32016109e+00,  1.97701819e+00,
        -3.39030059e+00, -5.67779351e-01,  6.70107684e-01,
         6.31443589e+00,  8.65991552e-01, -1.58633137e-01,
        -3.52940090e+00, -6.81737794e-01,  2.47187038e+00,
         1.21588602e+00, -2.22346979e+00,  1.37364649e+00,
        -1.79895009e+00,  3.03710592e+00, -2.63278986e+00,
         3.68918985e+00, -6.08509461e-01,  2.45039011e-01,
        -6.63479061e-01, -1.50727140e+00,  1.10449110e+00,
        -4.58384385e-01,  3.40399894e-01, -2.67878895e-01,
        -1.87647893e+00, -2.04332870e-01,  4.61919057e-01,
        -2.44538953e-02,  8.66380644e-04, -7.56583008e-02,
         1.91237218e-01, -4.73950435e-02,  2.74122911e-02,
         4.32524378e-03, -3.66956686e-03, -1.88790754e-03,
         5.22119207e-03, -1.86775268e-03, -5.07041881e-03,
        -1.70805502e-03,  1.87088367e-15, -3.01154459e-15,
         2.24048193e-16]])}
onnx outputs
[array([[-9.78696918e+00,  7.22639418e+00, -2.16935596e+01,
         1.13765850e+01, -3.54566121e+00, -5.59543371e+00,
         4.71459913e+00,  4.29410172e+00, -5.71520233e+00,
        -3.31533718e+00, -3.42040539e-01,  2.90474844e+00,
        -3.18177342e-01, -6.66362762e-01, -2.82714128e+00,
         5.91632557e+00, -9.69543815e-01,  1.92676806e+00,
         1.71450746e+00,  9.60454881e-01,  3.81571263e-01,
        -1.37130213e+00,  4.29353619e+00,  2.32392645e+00,
         7.13255882e-01,  3.00982118e+00, -1.98303699e+00,
        -4.81811404e-01, -1.90929934e-01,  2.03950286e+00,
         1.59803450e+00, -1.46831572e+00, -1.70903301e+00,
         7.93112069e-02, -1.62244260e-01,  5.10617606e-02,
        -6.63308799e-01,  1.35869288e+00, -1.03930473e+00,
         2.09485388e+00,  2.15669155e+00, -7.78041705e-02,
        -4.01349142e-02,  8.40159237e-01, -4.74891722e-01,
        -1.14564866e-01, -5.31819277e-02, -6.87010169e-01,
        -1.29090086e-01,  2.12032884e-01,  3.63901585e-01,
        -1.29285216e-01, -8.14384818e-02, -3.82919535e-02,
        -9.76885669e-03, -1.39046200e-02,  1.59100525e-03,
        -2.87444773e-03,  5.75120188e-03,  1.85595278e-03,
        -5.00911009e-03, -2.53068203e-15, -6.30369331e-16,
        -9.16970128e-16],
       [ 1.54267330e+01, -4.91291523e+00,  1.74676971e+01,
        -1.13960505e+01,  5.64554977e+00, -5.73695993e+00,
        -2.08026457e+00,  5.23721600e+00,  3.37859321e+00,
         3.60754204e+00,  2.90967607e+00, -3.75628328e+00,
        -1.21238220e+00, -5.21796322e+00, -4.95051479e+00,
        -4.01835155e+00, -2.97046089e+00, -5.64772224e+00,
         5.61898088e+00, -4.32016134e+00,  1.97701883e+00,
        -3.39030147e+00, -5.67779541e-01,  6.70108199e-01,
         6.31443739e+00,  8.65990937e-01, -1.58633217e-01,
        -3.52940059e+00, -6.81736946e-01,  2.47186923e+00,
         1.21588576e+00, -2.22346997e+00,  1.37364638e+00,
        -1.79894984e+00,  3.03710651e+00, -2.63278937e+00,
         3.68918991e+00, -6.08509481e-01,  2.45039046e-01,
        -6.63479507e-01, -1.50727105e+00,  1.10449100e+00,
        -4.58384484e-01,  3.40399802e-01, -2.67878950e-01,
        -1.87647831e+00, -2.04333529e-01,  4.61919039e-01,
        -2.44537946e-02,  8.66464688e-04, -7.56583288e-02,
         1.91237196e-01, -4.73950393e-02,  2.74122953e-02,
         4.32524411e-03, -3.66956298e-03, -1.88790704e-03,
         5.22119273e-03, -1.86775194e-03, -5.07041626e-03,
        -1.70805526e-03,  1.87088423e-15, -3.01154475e-15,
         2.24048182e-16]], dtype=float32)]
step 1 <class 'sklearn.linear_model._logistic.LogisticRegression'>
skl outputs
{'decision_function': array([[9.99998536e-01, 5.99063158e-19, 3.48548953e-10, 1.55765726e-08,
        3.32559745e-10, 1.21314653e-06, 3.98959930e-08, 1.22513839e-07,
        2.23871272e-08, 4.98148509e-08],
       [1.47648437e-14, 9.99999301e-01, 1.05811967e-10, 7.49298733e-13,
        2.48627417e-07, 8.75686484e-12, 5.39025135e-11, 2.95899938e-11,
        4.50528833e-07, 1.30607478e-13]]), 'predict_proba': array([[9.99998536e-01, 5.99063158e-19, 3.48548953e-10, 1.55765726e-08,
        3.32559745e-10, 1.21314653e-06, 3.98959930e-08, 1.22513839e-07,
        2.23871272e-08, 4.98148509e-08],
       [1.47648437e-14, 9.99999301e-01, 1.05811967e-10, 7.49298733e-13,
        2.48627417e-07, 8.75686484e-12, 5.39025135e-11, 2.95899938e-11,
        4.50528833e-07, 1.30607478e-13]])}
onnx outputs
[array([0, 1], dtype=int64), array([[9.9999857e-01, 5.9906250e-19, 3.4855036e-10, 1.5576649e-08,
        3.3256181e-10, 1.2131513e-06, 3.9896193e-08, 1.2251471e-07,
        2.2387249e-08, 4.9815153e-08],
       [1.4764896e-14, 9.9999928e-01, 1.0581199e-10, 7.4929749e-13,
        2.4862788e-07, 8.7568555e-12, 5.3902442e-11, 2.9589952e-11,
        4.5052906e-07, 1.3060734e-13]], dtype=float32)]

Pickle

Each step is a separate model in the pipeline. It can be pickled independently of the others. Attribute _debug contains all the information needed to replay the prediction of the model.

# Bundle the classifier step with the data it recorded so its prediction can
# be replayed on its own.
last_step = steps[1]
classifier = last_step["model"]
to_save = dict(
    model=classifier,
    data_input=classifier._debug.inputs,
    data_output=classifier._debug.outputs,
    inputs=last_step["inputs"],
    outputs=last_step["outputs"],
)
# The recorded data is already in ``to_save``; remove the attribute from the
# model object itself before it is pickled.
del classifier._debug

with open("classifier.pkl", "wb") as fh:
    pickle.dump(to_save, fh)

# Reload the bundle (only unpickle files you created or trust) and replay the
# recorded predict_proba call.
with open("classifier.pkl", "rb") as fh:
    restored = pickle.load(fh)

replay_input = restored["data_input"]["predict_proba"]
print(restored["model"].predict_proba(replay_input))
[[9.99998536e-01 5.99063158e-19 3.48548953e-10 1.55765726e-08
  3.32559745e-10 1.21314653e-06 3.98959930e-08 1.22513839e-07
  2.23871272e-08 4.98148509e-08]
 [1.47648437e-14 9.99999301e-01 1.05811967e-10 7.49298733e-13
  2.48627417e-07 8.75686484e-12 5.39025135e-11 2.95899938e-11
  4.50528833e-07 1.30607478e-13]]

Versions used for this example

# Report the library versions this example was run with.
for label, module in (
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
):
    print(label, module.__version__)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx:  1.15.0
onnxruntime:  1.16.0+cu118
skl2onnx:  1.16.0

Total running time of the script: (0 minutes 0.279 seconds)

Gallery generated by Sphinx-Gallery