Benchmark a pipeline

The following example checks every step of a pipeline, then compares and benchmarks the predictions of scikit-learn and onnxruntime.

Create a pipeline

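We reuse the pipeline implemented in the scikit-learn example "Pipelining: chaining a PCA and a logistic regression". Note that the ONNX-ML Imputer does not handle string types, so an imputer applied to string columns could not be part of the final ONNX pipeline and would have to be removed before conversion. The pipeline below (a PCA followed by a LogisticRegression) contains no such step, so it is converted unchanged.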

import skl2onnx
import onnx
import sklearn
import numpy
from skl2onnx.helpers import collect_intermediate_steps
from timeit import timeit
from skl2onnx.helpers import compare_objects
import onnxruntime as rt
from onnxconverter_common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
import numpy as np
import pandas as pd

from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Keep the first 1000 samples of the digits dataset (features and labels).
n_samples = 1000
digits = datasets.load_digits()
X_digits, y_digits = digits.data[:n_samples], digits.target[:n_samples]

# Two-step pipeline: a PCA projection followed by a logistic regression.
# In the run shown here, lbfgs reached its iteration limit and emitted a
# ConvergenceWarning; raising max_iter or scaling the inputs would address it.
pca = PCA()
logistic = LogisticRegression()
pipe = Pipeline(steps=[('pca', pca), ('logistic', logistic)])
pipe.fit(X_digits, y_digits)
/home/xadupre/github/scikit-learn/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
Pipeline(steps=[('pca', PCA()), ('logistic', LogisticRegression())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Conversion to ONNX

# The ONNX graph has a single float input named 'input', with a dynamic
# number of rows and one column per feature of X_digits.
n_features = X_digits.shape[1]
initial_types = [('input', FloatTensorType((None, n_features)))]
model_onnx = convert_sklearn(
    pipe, initial_types=initial_types, target_opset=12)

# Load the converted model into onnxruntime, restricted to the CPU provider.
serialized = model_onnx.SerializeToString()
sess = rt.InferenceSession(serialized, providers=["CPUExecutionProvider"])

X_sample = X_digits[:2]
print("skl predict_proba")
print(pipe.predict_proba(X_sample))

# onnxruntime is fed a float32 copy; the second output (index 1) is the one
# printed below as the ONNX predict_proba.
feeds = {'input': X_sample.astype(np.float32)}
onx_pred = sess.run(None, feeds)[1]
# Go through a DataFrame to print the probabilities as a plain array.
# NOTE(review): presumably needed because this output may be a list of
# per-row dicts rather than an ndarray -- confirm.
df = pd.DataFrame(onx_pred)
print("onnx predict_proba")
print(df.values)
skl predict_proba
[[9.99998536e-01 5.99063158e-19 3.48548953e-10 1.55765726e-08
  3.32559745e-10 1.21314653e-06 3.98959930e-08 1.22513839e-07
  2.23871272e-08 4.98148509e-08]
 [1.47648437e-14 9.99999301e-01 1.05811967e-10 7.49298733e-13
  2.48627417e-07 8.75686484e-12 5.39025135e-11 2.95899938e-11
  4.50528833e-07 1.30607478e-13]]
onnx predict_proba
[[9.99998569e-01 5.99062501e-19 3.48550355e-10 1.55766493e-08
  3.32561811e-10 1.21315134e-06 3.98961930e-08 1.22514706e-07
  2.23872494e-08 4.98151529e-08]
 [1.47648956e-14 9.99999285e-01 1.05811991e-10 7.49297488e-13
  2.48627885e-07 8.75685548e-12 5.39024415e-11 2.95899520e-11
  4.50529058e-07 1.30607344e-13]]

Comparing outputs

# compare_objects raises an exception when the two outputs differ, so getting
# past this call means scikit-learn and onnxruntime produced the same result.
skl_pred = pipe.predict_proba(X_digits[:2])
compare_objects(skl_pred, onx_pred)

Benchmarks

# Time 10000 single-row predictions with each runtime. The float32 copy that
# onnxruntime needs is made once, outside the timed statement, so both timings
# measure prediction only instead of charging onnxruntime for a dtype
# conversion on every call.
X_one = X_digits[:1]
X_one_f32 = X_one.astype(np.float32)
print("scikit-learn")
print(timeit("pipe.predict_proba(X_one)",
             number=10000, globals=globals()))
print("onnxruntime")
print(timeit("sess.run(None, {'input': X_one_f32})[1]",
             number=10000, globals=globals()))
scikit-learn
2.355312850000246
onnxruntime
0.29348953099997743

Intermediate steps

Let’s imagine the final output is wrong and we need to inspect each component of the pipeline to find which one is failing. The following function modifies the scikit-learn pipeline to capture the intermediate outputs and produces a smaller ONNX graph for every operator.

# collect_intermediate_steps modifies the scikit-learn pipeline so that each
# step records the outputs it computes (read back through ``_debug.outputs``).
# For every step, it returns a dict holding the step's model ('model') and a
# smaller ONNX graph ('onnx_step').
steps = collect_intermediate_steps(
    pipe, "pipeline", initial_types)

# This pipeline has two steps: PCA, then LogisticRegression.
assert len(steps) == 2

# The return value is discarded: this call only makes every modified step
# store the outputs it produces on these two rows.
pipe.predict_proba(X_digits[:2])

# Single-row float32 input for the onnxruntime benchmarks, built once so the
# dtype conversion is not part of the timed statement.
X_one_f32 = X_digits[:1].astype(np.float32)

for step in steps:
    model = step['model']
    # A dedicated name keeps the full-pipeline session ``sess`` intact.
    step_sess = rt.InferenceSession(step['onnx_step'].SerializeToString(),
                                    providers=["CPUExecutionProvider"])
    onnx_outputs = step_sess.run(
        None, {'input': X_digits[:2].astype(np.float32)})
    skl_outputs = model._debug.outputs
    # The transformer step records 'transform', checked against the first
    # ONNX output; the final classifier records 'predict_proba', checked
    # against the second one.
    if 'transform' in skl_outputs:
        method, onnx_index = 'transform', 0
    else:
        method, onnx_index = 'predict_proba', 1
    compare_objects(skl_outputs[method], onnx_outputs[onnx_index])
    print("benchmark", model.__class__)
    print("scikit-learn")
    # NOTE(review): ``model`` was altered by collect_intermediate_steps, so this
    # timing may include the cost of recording debug outputs.
    print(timeit("model.%s(X_digits[:1])" % method,
                 number=10000, globals=globals()))
    print("onnxruntime")
    print(timeit("step_sess.run(None, {'input': X_one_f32})",
                 number=10000, globals=globals()))
benchmark <class 'sklearn.decomposition._pca.PCA'>
scikit-learn
0.6831115730001329
onnxruntime
0.16402971700017588
benchmark <class 'sklearn.linear_model._logistic.LogisticRegression'>
scikit-learn
1.4586870539997108
onnxruntime
0.15432031699992876

Versions used for this example

# Report the library versions this example was run with.
for label, lib in (
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
):
    print(label, lib.__version__)
numpy: 1.23.5
scikit-learn: 1.3.dev0
onnx:  1.14.0
onnxruntime:  1.15.0+cpu
skl2onnx:  1.14.0

Total running time of the script: ( 0 minutes 6.685 seconds)

Gallery generated by Sphinx-Gallery