Convert a pipeline with an XGBoost model

sklearn-onnx only converts scikit-learn models into ONNX, but many libraries implement the scikit-learn API so that their models can be included in a scikit-learn pipeline. This example considers a pipeline that includes an XGBoost model. sklearn-onnx can convert the whole pipeline as long as it knows the converter associated with an XGBClassifier. Let’s see how to do it.

Train an XGBoost classifier

import os
import subprocess

import numpy
import matplotlib.pyplot as plt
import onnx
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnxruntime as rt
import sklearn
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import xgboost
from xgboost import XGBClassifier
import skl2onnx
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
)  # noqa
import onnxmltools
from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
    convert_xgboost,
)  # noqa
import onnxmltools.convert.common.data_types

# Keep only the first two iris features, so the model input has two
# columns; this matches the FloatTensorType([None, 2]) used for conversion.
data = load_iris()
X = data.data[:, :2]
y = data.target

# Shuffle the samples with a seeded generator so that reruns of the example
# produce the same model and predictions. Fancy indexing already returns
# new arrays, so no explicit copy is needed.
rng = numpy.random.default_rng(0)
ind = rng.permutation(X.shape[0])
X = X[ind, :]
y = y[ind]

# Standardize the features, then fit a small XGBoost classifier. The step
# name reflects the model it actually holds.
pipe = Pipeline([("scaler", StandardScaler()), ("xgb", XGBClassifier(n_estimators=3))])
pipe.fit(X, y)

# The conversion fails, but this is expected: no converter is registered
# yet for XGBClassifier. skl2onnx signals this with MissingShapeCalculator
# (a RuntimeError subclass), so only RuntimeError is caught here. Any other
# exception is a genuine problem and is allowed to propagate.
try:
    convert_sklearn(
        pipe,
        "pipeline_xgboost",
        [("input", FloatTensorType([None, 2]))],
        target_opset={"": 12, "ai.onnx.ml": 2},
    )
except RuntimeError as e:
    print(e)
else:
    # Only reached when a converter was already registered, e.g. when this
    # script is re-run in the same interpreter session.
    print("Conversion succeeded: a converter for XGBClassifier is already registered.")

# The error message says that no converter was found
# for XGBoost models. By default, *sklearn-onnx*
# only handles models from *scikit-learn*, but it can
# be extended to any model following the *scikit-learn*
# API as long as the module knows a converter exists
# for every model used in the pipeline. That's why
# we need to register a converter.
Unable to find a shape calculator for type '<class 'xgboost.sklearn.XGBClassifier'>'.
It usually means the pipeline being converted contains a
transformer or a predictor with no corresponding converter
implemented in sklearn-onnx. If the converted is implemented
in another library, you need to register
the converted so that it can be used by sklearn-onnx (function
update_registered_converter). If the model is not yet covered
by sklearn-onnx, you may raise an issue to
https://github.com/onnx/sklearn-onnx/issues
to get the converter implemented or even contribute to the
project. If the model is a custom model, a new converter must
be implemented. Examples can be found in the gallery.

Register the converter for XGBClassifier

The converter is implemented in onnxmltools, in onnxmltools/convert/xgboost/operator_converters/XGBoost.py, and the corresponding shape calculator in onnxmltools/convert/xgboost/shape_calculators/Classifier.py.

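The converter (convert_xgboost, from onnxmltools) and the shape calculator (calculate_linear_classifier_output_shapes, from skl2onnx) were imported at the top of the script.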

Let’s register the new converter.

# Options the registered converter declares it accepts, each mapped to its
# list of allowed values.
xgb_converter_options = {
    "nocl": [True, False],
    "zipmap": [True, False, "columns"],
}

# Associate XGBClassifier with an alias, a shape calculator (from skl2onnx)
# and a converter (from onnxmltools), so convert_sklearn can handle it.
update_registered_converter(
    XGBClassifier,
    "XGBoostXGBClassifier",
    calculate_linear_classifier_output_shapes,
    convert_xgboost,
    options=xgb_converter_options,
)

Convert again

# Second attempt: with XGBClassifier now registered, the whole pipeline
# (scaler + classifier) can be translated to ONNX.
initial_types = [("input", FloatTensorType([None, 2]))]
model_onnx = convert_sklearn(
    pipe,
    "pipeline_xgboost",
    initial_types,
    target_opset={"": 12, "ai.onnx.ml": 2},
)

# And save the serialized model, so onnxruntime can load it from disk.
serialized = model_onnx.SerializeToString()
with open("pipeline_xgboost.onnx", "wb") as f:
    f.write(serialized)

Compare the predictions

Predictions with XGBoost.

# Reference results from the fitted scikit-learn pipeline: labels for the
# first five samples and class probabilities for the first sample.
labels_skl = pipe.predict(X[:5])
proba_skl = pipe.predict_proba(X[:1])
print("predict", labels_skl)
print("predict_proba", proba_skl)
predict [2 2 2 0 0]
predict_proba [[0.17929651 0.23665468 0.5840488 ]]

Predictions with onnxruntime.

# Score the same five samples with the saved ONNX model. The input is cast
# to float32 to match the FloatTensorType declared at conversion time.
sess = rt.InferenceSession("pipeline_xgboost.onnx", providers=["CPUExecutionProvider"])
feeds = {"input": X[:5].astype(numpy.float32)}
pred_onx = sess.run(None, feeds)

# Output 0 holds the predicted labels; output 1 holds the probabilities
# (one {class: probability} dict per sample in the captured run).
label_onx = pred_onx[0]
proba_onx = pred_onx[1]
print("predict", label_onx)
print("predict_proba", proba_onx[:1])
predict [2 2 2 0 0]
predict_proba [{0: 0.17929650843143463, 1: 0.23665468394756317, 2: 0.5840488076210022}]

Display the ONNX graph

# Build a pydot graph of the converted model, top to bottom, with operator
# nodes drawn as filled yellow boxes, and write it to a Graphviz .dot file.
pydot_graph = GetPydotGraph(
    model_onnx.graph,
    name=model_onnx.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("pipeline.dot")

# Render the .dot file to PNG with Graphviz's ``dot`` executable. The
# command is passed as an argument list (no shell) with check=True, so a
# missing or failing ``dot`` raises here. The previous os.system call
# ignored the exit status, so such a failure only showed up later as a
# confusing plt.imread error.
subprocess.run(["dot", "-O", "-Gdpi=300", "-Tpng", "pipeline.dot"], check=True)

# ``dot -O`` derives the output name from the input name plus the format,
# hence "pipeline.dot.png".
image = plt.imread("pipeline.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
plot pipeline xgboost
(-0.5, 2485.5, 2558.5, -0.5)

Versions used for this example

# Report the version of every library involved, so the captured outputs can
# be reproduced. Each __version__ is read just before its line is printed.
for label, module in (
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
    ("onnxmltools: ", onnxmltools),
    ("xgboost: ", xgboost),
):
    print(label, module.__version__)
numpy: 1.26.2
scikit-learn: 1.5.dev0
onnx:  1.16.0
onnxruntime:  1.17.0+cu118
skl2onnx:  1.17.0
onnxmltools:  1.13.0
xgboost:  2.0.3

Total running time of the script: (0 minutes 2.227 seconds)

Gallery generated by Sphinx-Gallery