Note
Go to the end to download the full example code
Benchmark a pipeline#
The following example checks every step of a pipeline: it compares the predictions with their ONNX counterparts and benchmarks both.
Create a pipeline#
We reuse the pipeline implemented in example
Pipelining: chaining a PCA and a logistic regression.
There is one change because
ONNX-ML Imputer
does not handle string types. It therefore cannot be part of the final ONNX pipeline
and must be removed. Look for the comment starting with ---
below.
import skl2onnx
import onnx
import sklearn
import numpy
from skl2onnx.helpers import collect_intermediate_steps
from timeit import timeit
from skl2onnx.helpers import compare_objects
import onnxruntime as rt
from onnxconverter_common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
# Training data: the first 1000 samples of the digits dataset.
digits = datasets.load_digits()
X_digits, y_digits = digits.data[:1000], digits.target[:1000]

# A PCA chained with a logistic regression, both with default parameters.
pca = PCA()
logistic = LogisticRegression()
pipe = Pipeline(steps=[("pca", pca), ("logistic", logistic)])
pipe.fit(X_digits, y_digits)
/home/xadupre/github/scikit-learn/sklearn/linear_model/_logistic.py:472: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
Conversion to ONNX#
# One float input named "input": the row count is left open (None) and there
# is one column per feature of the training data.
n_features = X_digits.shape[1]
initial_types = [("input", FloatTensorType((None, n_features)))]

# Convert the fitted pipeline, then load the serialized model with the
# onnxruntime CPU provider.
model_onnx = convert_sklearn(pipe, initial_types=initial_types, target_opset=12)
serialized_model = model_onnx.SerializeToString()
sess = rt.InferenceSession(serialized_model, providers=["CPUExecutionProvider"])

# Predict the same two rows with both runtimes.
first_rows = X_digits[:2]
print("skl predict_proba")
print(pipe.predict_proba(first_rows))

# The ONNX session is fed float32 data; output 1 is the one compared with
# predict_proba (output 0 presumably holds the predicted labels -- confirm).
onnx_feeds = {"input": first_rows.astype(np.float32)}
onx_pred = sess.run(None, onnx_feeds)[1]
# Going through a DataFrame yields a plain 2-D array for printing.
df = pd.DataFrame(onx_pred)
print("onnx predict_proba")
print(df.values)
skl predict_proba
[[9.99998536e-01 5.99063158e-19 3.48548953e-10 1.55765726e-08
3.32559745e-10 1.21314653e-06 3.98959930e-08 1.22513839e-07
2.23871272e-08 4.98148509e-08]
[1.47648437e-14 9.99999301e-01 1.05811967e-10 7.49298733e-13
2.48627417e-07 8.75686484e-12 5.39025135e-11 2.95899938e-11
4.50528833e-07 1.30607478e-13]]
onnx predict_proba
[[9.99998569e-01 5.99062501e-19 3.48550355e-10 1.55766493e-08
3.32561811e-10 1.21315134e-06 3.98961930e-08 1.22514706e-07
2.23872494e-08 4.98151529e-08]
[1.47648956e-14 9.99999285e-01 1.05811991e-10 7.49297488e-13
2.48627885e-07 8.75685548e-12 5.39024415e-11 2.95899520e-11
4.50529058e-07 1.30607344e-13]]
Comparing outputs#
# Check the scikit-learn probabilities against the ONNX ones computed above.
skl_pred = pipe.predict_proba(X_digits[:2])
compare_objects(skl_pred, onx_pred)
# Reaching this line without an exception means both outputs are the same.
Benchmarks#
# Each entry pairs a printed label with the statement timed for that runtime.
# Both statements predict a single row and are executed 10000 times.
benchmarks = [
    ("scikit-learn", "pipe.predict_proba(X_digits[:1])"),
    ("onnxruntime", "sess.run(None, {'input': X_digits[:1].astype(np.float32)})[1]"),
]
for label, statement in benchmarks:
    print(label)
    # timeit resolves the names used in the statement through globals().
    print(timeit(statement, number=10000, globals=globals()))
scikit-learn
3.2031071999999767
onnxruntime
0.26938670000004095
Intermediate steps#
Let’s imagine the final output is wrong and we need to look into each component of the pipeline to find which one is failing. The following method modifies the scikit-learn pipeline to steal the intermediate outputs and produces a smaller ONNX graph for every operator.
# Convert every step of the pipeline separately. Each entry of ``steps``
# exposes the scikit-learn estimator under "model" and its ONNX graph under
# "onnx_step".
steps = collect_intermediate_steps(pipe, "pipeline", initial_types)

# The pipeline has two steps (PCA and logistic regression).
assert len(steps) == 2

# Run the modified pipeline once on the two rows compared below; presumably
# this is what fills ``_debug.outputs`` on each step -- TODO confirm against
# the skl2onnx documentation.
pipe.predict_proba(X_digits[:2])

# ``step`` and ``sess`` must remain module-level names: the timed statements
# below look them up through ``globals()``.
for step in steps:
    onnx_step = step["onnx_step"]
    sess = rt.InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"input": X_digits[:2].astype(np.float32)})
    skl_outputs = step["model"]._debug.outputs

    # A step that recorded a "transform" output is checked against ONNX
    # output 0; any other step is checked on "predict_proba" against ONNX
    # output 1.
    if "transform" in skl_outputs:
        method, onnx_index = "transform", 0
    else:
        method, onnx_index = "predict_proba", 1
    compare_objects(skl_outputs[method], onnx_outputs[onnx_index])

    print("benchmark", step["model"].__class__)
    print("scikit-learn")
    print(
        timeit(
            f"step['model'].{method}(X_digits[:1])",
            number=10000,
            globals=globals(),
        )
    )

    # Time the ONNX graph of the same step on a single row.
    print("onnxruntime")
    print(
        timeit(
            "sess.run(None, {'input': X_digits[:1].astype(np.float32)})",
            number=10000,
            globals=globals(),
        )
    )
benchmark <class 'sklearn.decomposition._pca.PCA'>
scikit-learn
0.8116694999998799
onnxruntime
0.19065660000001117
benchmark <class 'sklearn.linear_model._logistic.LogisticRegression'>
scikit-learn
1.1337754000001041
onnxruntime
0.20516539999994166
Versions used for this example
# Report the version of every package involved in this example.
for package_label, package in (
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
):
    print(package_label, package.__version__)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx: 1.15.0
onnxruntime: 1.16.0+cu118
skl2onnx: 1.15.0
Total running time of the script: (0 minutes 6.748 seconds)