Investigate a pipeline#
The following example shows how to look into a converted model and easily find errors at every step of the pipeline.
Create a pipeline#
We reuse the pipeline implemented in the example
Pipelining: chaining a PCA and a logistic regression.
There is one change: the
ONNX-ML Imputer
does not handle string types, so it cannot be part of the final ONNX
pipeline and must be removed.
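For illustration only, here is a hypothetical sketch of what removing such a step could look like. It assumes a pipeline full_pipe whose first step is a string-handling SimpleImputer named "imputer" (neither exists in this example): the imputation would be applied beforehand, for instance with pandas, and only the remaining steps converted.
from sklearn.decomposition import PCA
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Hypothetical pipeline whose first step cannot be converted to ONNX.
full_pipe = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("pca", PCA()),
        ("logistic", LogisticRegression()),
    ]
)

# Keep only the convertible steps; imputation happens outside of ONNX.
convertible = Pipeline(
    steps=[(name, est) for name, est in full_pipe.steps if name != "imputer"]
)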
import pickle

import numpy as np
import pandas as pd
import onnx
import onnxruntime as rt
import skl2onnx
import sklearn
from onnxconverter_common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from skl2onnx.helpers import collect_intermediate_steps
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
# Train a PCA followed by a logistic regression on the first
# 1000 digits of the dataset.
pipe = Pipeline(steps=[("pca", PCA()), ("logistic", LogisticRegression())])
digits = datasets.load_digits()
X_digits = digits.data[:1000]
y_digits = digits.target[:1000]
pipe.fit(X_digits, y_digits)
/home/xadupre/github/scikit-learn/sklearn/linear_model/_logistic.py:472: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
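The warning does not matter for this example. If you wanted to make it disappear, the message itself suggests two remedies: give lbfgs a larger iteration budget or scale the data first. A minimal sketch of both options (StandardScaler and max_iter=1000 are assumptions, not part of the pipeline used below):
from sklearn.preprocessing import StandardScaler

# Option 1: raise the iteration budget of the solver.
pipe_more_iter = Pipeline(
    steps=[("pca", PCA()), ("logistic", LogisticRegression(max_iter=1000))]
)

# Option 2: standardize the features before PCA.
pipe_scaled = Pipeline(
    steps=[
        ("scaler", StandardScaler()),
        ("pca", PCA()),
        ("logistic", LogisticRegression()),
    ]
)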
Conversion to ONNX#
# Convert the whole pipeline into a single ONNX graph.
initial_types = [("input", FloatTensorType((None, X_digits.shape[1])))]
model_onnx = convert_sklearn(pipe, initial_types=initial_types, target_opset=12)
sess = rt.InferenceSession(
    model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)

print("skl predict_proba")
print(pipe.predict_proba(X_digits[:2]))
# The second ONNX output contains the probabilities,
# the first one the predicted labels.
onx_pred = sess.run(None, {"input": X_digits[:2].astype(np.float32)})[1]
df = pd.DataFrame(onx_pred)
print("onnx predict_proba")
print(df.values)
skl predict_proba
[[9.99998536e-01 5.99063158e-19 3.48548953e-10 1.55765726e-08
3.32559745e-10 1.21314653e-06 3.98959930e-08 1.22513839e-07
2.23871272e-08 4.98148509e-08]
[1.47648437e-14 9.99999301e-01 1.05811967e-10 7.49298733e-13
2.48627417e-07 8.75686484e-12 5.39025135e-11 2.95899938e-11
4.50528833e-07 1.30607478e-13]]
onnx predict_proba
[[9.99998569e-01 5.99062501e-19 3.48550355e-10 1.55766493e-08
3.32561811e-10 1.21315134e-06 3.98961930e-08 1.22514706e-07
2.23872494e-08 4.98151529e-08]
[1.47648956e-14 9.99999285e-01 1.05811991e-10 7.49297488e-13
2.48627885e-07 8.75685548e-12 5.39024415e-11 2.95899520e-11
4.50529058e-07 1.30607344e-13]]
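The two matrices only differ by single-precision rounding. Rather than comparing them by eye, the check can be automated, for instance with numpy.testing; the tolerances below are an assumption suited to float32.
from numpy.testing import assert_allclose

# ONNX computes in float32, so expect roughly 6 significant digits.
assert_allclose(pipe.predict_proba(X_digits[:2]), onx_pred, rtol=1e-4, atol=1e-7)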
Intermediate steps#
Let’s imagine the final output is wrong and we need to look into each component of the pipeline to find which one is failing. The following method modifies the scikit-learn pipeline to capture the intermediate outputs and produces a smaller ONNX graph for every operator.
steps = collect_intermediate_steps(pipe, "pipeline", initial_types)

assert len(steps) == 2
# Running the scikit-learn pipeline once records every intermediate
# output in the _debug attribute of each step.
pipe.predict_proba(X_digits[:2])

for i, step in enumerate(steps):
    onnx_step = step["onnx_step"]
    sess = rt.InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"input": X_digits[:2].astype(np.float32)})
    skl_outputs = step["model"]._debug.outputs
    print("step", i + 1, type(step["model"]))
    print("skl outputs")
    print(skl_outputs)
    print("onnx outputs")
    print(onnx_outputs)
step 1 <class 'sklearn.decomposition._pca.PCA'>
skl outputs
{'transform': array([[-9.78697129e+00, 7.22639567e+00, -2.16935601e+01,
1.13765854e+01, -3.54566122e+00, -5.59543345e+00,
4.71459904e+00, 4.29410146e+00, -5.71520266e+00,
-3.31533698e+00, -3.42040920e-01, 2.90474751e+00,
-3.18177631e-01, -6.66363079e-01, -2.82714171e+00,
5.91632481e+00, -9.69544780e-01, 1.92676767e+00,
1.71450677e+00, 9.60454853e-01, 3.81570991e-01,
-1.37130203e+00, 4.29353551e+00, 2.32392659e+00,
7.13256034e-01, 3.00982060e+00, -1.98303620e+00,
-4.81811365e-01, -1.90930400e-01, 2.03950266e+00,
1.59803428e+00, -1.46831581e+00, -1.70903280e+00,
7.93109126e-02, -1.62244448e-01, 5.10619572e-02,
-6.63308841e-01, 1.35869345e+00, -1.03930533e+00,
2.09485311e+00, 2.15669105e+00, -7.78040093e-02,
-4.01347652e-02, 8.40159293e-01, -4.74891758e-01,
-1.14564701e-01, -5.31817617e-02, -6.87010227e-01,
-1.29090165e-01, 2.12032919e-01, 3.63901656e-01,
-1.29285214e-01, -8.14384613e-02, -3.82919696e-02,
-9.76885583e-03, -1.39046240e-02, 1.59100433e-03,
-2.87444919e-03, 5.75119957e-03, 1.85595427e-03,
-5.00911047e-03, -2.53068224e-15, -6.30369386e-16,
-9.16970102e-16],
[ 1.54267314e+01, -4.91291516e+00, 1.74676972e+01,
-1.13960509e+01, 5.64555024e+00, -5.73696034e+00,
-2.08026490e+00, 5.23721537e+00, 3.37859393e+00,
3.60754149e+00, 2.90967608e+00, -3.75628331e+00,
-1.21238177e+00, -5.21796290e+00, -4.95051435e+00,
-4.01835168e+00, -2.97046115e+00, -5.64772188e+00,
5.61898054e+00, -4.32016109e+00, 1.97701819e+00,
-3.39030059e+00, -5.67779351e-01, 6.70107684e-01,
6.31443589e+00, 8.65991552e-01, -1.58633137e-01,
-3.52940090e+00, -6.81737794e-01, 2.47187038e+00,
1.21588602e+00, -2.22346979e+00, 1.37364649e+00,
-1.79895009e+00, 3.03710592e+00, -2.63278986e+00,
3.68918985e+00, -6.08509461e-01, 2.45039011e-01,
-6.63479061e-01, -1.50727140e+00, 1.10449110e+00,
-4.58384385e-01, 3.40399894e-01, -2.67878895e-01,
-1.87647893e+00, -2.04332870e-01, 4.61919057e-01,
-2.44538953e-02, 8.66380644e-04, -7.56583008e-02,
1.91237218e-01, -4.73950435e-02, 2.74122911e-02,
4.32524378e-03, -3.66956686e-03, -1.88790754e-03,
5.22119207e-03, -1.86775268e-03, -5.07041881e-03,
-1.70805502e-03, 1.87088367e-15, -3.01154459e-15,
2.24048193e-16]])}
onnx outputs
[array([[-9.78696918e+00, 7.22639418e+00, -2.16935596e+01,
1.13765850e+01, -3.54566121e+00, -5.59543371e+00,
4.71459913e+00, 4.29410172e+00, -5.71520233e+00,
-3.31533718e+00, -3.42040539e-01, 2.90474844e+00,
-3.18177342e-01, -6.66362762e-01, -2.82714128e+00,
5.91632557e+00, -9.69543815e-01, 1.92676806e+00,
1.71450746e+00, 9.60454881e-01, 3.81571263e-01,
-1.37130213e+00, 4.29353619e+00, 2.32392645e+00,
7.13255882e-01, 3.00982118e+00, -1.98303699e+00,
-4.81811404e-01, -1.90929934e-01, 2.03950286e+00,
1.59803450e+00, -1.46831572e+00, -1.70903301e+00,
7.93112069e-02, -1.62244260e-01, 5.10617606e-02,
-6.63308799e-01, 1.35869288e+00, -1.03930473e+00,
2.09485388e+00, 2.15669155e+00, -7.78041705e-02,
-4.01349142e-02, 8.40159237e-01, -4.74891722e-01,
-1.14564866e-01, -5.31819277e-02, -6.87010169e-01,
-1.29090086e-01, 2.12032884e-01, 3.63901585e-01,
-1.29285216e-01, -8.14384818e-02, -3.82919535e-02,
-9.76885669e-03, -1.39046200e-02, 1.59100525e-03,
-2.87444773e-03, 5.75120188e-03, 1.85595278e-03,
-5.00911009e-03, -2.53068203e-15, -6.30369331e-16,
-9.16970128e-16],
[ 1.54267330e+01, -4.91291523e+00, 1.74676971e+01,
-1.13960505e+01, 5.64554977e+00, -5.73695993e+00,
-2.08026457e+00, 5.23721600e+00, 3.37859321e+00,
3.60754204e+00, 2.90967607e+00, -3.75628328e+00,
-1.21238220e+00, -5.21796322e+00, -4.95051479e+00,
-4.01835155e+00, -2.97046089e+00, -5.64772224e+00,
5.61898088e+00, -4.32016134e+00, 1.97701883e+00,
-3.39030147e+00, -5.67779541e-01, 6.70108199e-01,
6.31443739e+00, 8.65990937e-01, -1.58633217e-01,
-3.52940059e+00, -6.81736946e-01, 2.47186923e+00,
1.21588576e+00, -2.22346997e+00, 1.37364638e+00,
-1.79894984e+00, 3.03710651e+00, -2.63278937e+00,
3.68918991e+00, -6.08509481e-01, 2.45039046e-01,
-6.63479507e-01, -1.50727105e+00, 1.10449100e+00,
-4.58384484e-01, 3.40399802e-01, -2.67878950e-01,
-1.87647831e+00, -2.04333529e-01, 4.61919039e-01,
-2.44537946e-02, 8.66464688e-04, -7.56583288e-02,
1.91237196e-01, -4.73950393e-02, 2.74122953e-02,
4.32524411e-03, -3.66956298e-03, -1.88790704e-03,
5.22119273e-03, -1.86775194e-03, -5.07041626e-03,
-1.70805526e-03, 1.87088423e-15, -3.01154475e-15,
2.24048182e-16]], dtype=float32)]
step 2 <class 'sklearn.linear_model._logistic.LogisticRegression'>
skl outputs
{'decision_function': array([[9.99998536e-01, 5.99063158e-19, 3.48548953e-10, 1.55765726e-08,
3.32559745e-10, 1.21314653e-06, 3.98959930e-08, 1.22513839e-07,
2.23871272e-08, 4.98148509e-08],
[1.47648437e-14, 9.99999301e-01, 1.05811967e-10, 7.49298733e-13,
2.48627417e-07, 8.75686484e-12, 5.39025135e-11, 2.95899938e-11,
4.50528833e-07, 1.30607478e-13]]), 'predict_proba': array([[9.99998536e-01, 5.99063158e-19, 3.48548953e-10, 1.55765726e-08,
3.32559745e-10, 1.21314653e-06, 3.98959930e-08, 1.22513839e-07,
2.23871272e-08, 4.98148509e-08],
[1.47648437e-14, 9.99999301e-01, 1.05811967e-10, 7.49298733e-13,
2.48627417e-07, 8.75686484e-12, 5.39025135e-11, 2.95899938e-11,
4.50528833e-07, 1.30607478e-13]])}
onnx outputs
[array([0, 1], dtype=int64), array([[9.9999857e-01, 5.9906250e-19, 3.4855036e-10, 1.5576649e-08,
3.3256181e-10, 1.2131513e-06, 3.9896193e-08, 1.2251471e-07,
2.2387249e-08, 4.9815153e-08],
[1.4764896e-14, 9.9999928e-01, 1.0581199e-10, 7.4929749e-13,
2.4862788e-07, 8.7568555e-12, 5.3902442e-11, 2.9589952e-11,
4.5052906e-07, 1.3060734e-13]], dtype=float32)]
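Printing full arrays quickly becomes hard to read. A more compact check, under the assumption that the last recorded scikit-learn output corresponds to the last ONNX output of each step (transform for PCA, predict_proba for the classifier), is to print only the largest absolute difference:
for i, step in enumerate(steps):
    sess = rt.InferenceSession(
        step["onnx_step"].SerializeToString(), providers=["CPUExecutionProvider"]
    )
    # Compare the last ONNX output with the last recorded sklearn output.
    onnx_out = sess.run(None, {"input": X_digits[:2].astype(np.float32)})[-1]
    skl_out = list(step["model"]._debug.outputs.values())[-1]
    print("step", i + 1, "max absolute difference:", np.abs(onnx_out - skl_out).max())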
Pickle#
Each step is a separate model in the pipeline. It can be pickled independently from the others. The attribute _debug contains all the information needed to replay the prediction of the model.
# Save the classifier step together with the inputs and outputs
# recorded during the debug run.
to_save = {
    "model": steps[1]["model"],
    "data_input": steps[1]["model"]._debug.inputs,
    "data_output": steps[1]["model"]._debug.outputs,
    "inputs": steps[1]["inputs"],
    "outputs": steps[1]["outputs"],
}
# Drop the debug attribute before pickling the model itself.
del steps[1]["model"]._debug

with open("classifier.pkl", "wb") as f:
    pickle.dump(to_save, f)
with open("classifier.pkl", "rb") as f:
    restored = pickle.load(f)

print(restored["model"].predict_proba(restored["data_input"]["predict_proba"]))
[[9.99998536e-01 5.99063158e-19 3.48548953e-10 1.55765726e-08
3.32559745e-10 1.21314653e-06 3.98959930e-08 1.22513839e-07
2.23871272e-08 4.98148509e-08]
[1.47648437e-14 9.99999301e-01 1.05811967e-10 7.49298733e-13
2.48627417e-07 8.75686484e-12 5.39025135e-11 2.95899938e-11
4.50528833e-07 1.30607478e-13]]
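The restored model can also be checked against the output recorded at debug time; with the same library versions and inputs, the replayed prediction should agree with the stored one:
from numpy.testing import assert_allclose

# Replay the prediction and compare it with the stored debug output.
assert_allclose(
    restored["data_output"]["predict_proba"],
    restored["model"].predict_proba(restored["data_input"]["predict_proba"]),
)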
Versions used for this example
print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx: 1.15.0
onnxruntime: 1.16.0+cu118
skl2onnx: 1.16.0
Total running time of the script: (0 minutes 0.279 seconds)