Convert a pipeline with a LightGBM classifier

sklearn-onnx only converts scikit-learn models into ONNX, but many libraries implement the scikit-learn API so that their models can be included in a scikit-learn pipeline. This example considers a pipeline that includes a LightGBM model. sklearn-onnx can convert the whole pipeline as long as it knows the converter associated with LGBMClassifier. Let’s see how to do it.

Train a LightGBM classifier

import onnxruntime as rt
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
)  # noqa
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)  # noqa
from skl2onnx.common.data_types import FloatTensorType
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from lightgbm import LGBMClassifier

# Load the iris dataset and keep only the first two features so the example
# stays small; the ONNX input declared at conversion time must match this width.
data = load_iris()
X = data.data[:, :2]
y = data.target

# Shuffle the samples with a *seeded* generator so that the fitted model, and
# therefore the predictions printed further down, are reproducible across runs.
# Fancy indexing already returns new arrays, so no explicit copy is needed.
rng = numpy.random.default_rng(0)
ind = rng.permutation(X.shape[0])
X = X[ind, :]
y = y[ind]

# Standardize the features, then fit a small LightGBM classifier (3 trees).
pipe = Pipeline(
    [("scaler", StandardScaler()), ("lgbm", LGBMClassifier(n_estimators=3))]
)
pipe.fit(X, y)
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000030 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 47
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 2
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
Pipeline(steps=[('scaler', StandardScaler()),
                ('lgbm', LGBMClassifier(n_estimators=3))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Register the converter for LGBMClassifier

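The converter is implemented in onnxmltools (onnxmltools…LightGbm.py). onnxmltools also provides a shape calculator (onnxmltools…Classifier.py), but this example registers skl2onnx’s calculate_linear_classifier_output_shapes instead.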

# Register LGBMClassifier with skl2onnx under the alias
# "LightGbmLGBMClassifier". Output shapes are computed by skl2onnx's linear
# classifier shape calculator; the graph itself is produced by the LightGBM
# converter imported from onnxmltools. The options mapping lists the
# conversion options accepted for this model type and their allowed values.
lgbm_options = {"nocl": [True, False], "zipmap": [True, False, "columns"]}
update_registered_converter(
    LGBMClassifier,
    "LightGbmLGBMClassifier",
    calculate_linear_classifier_output_shapes,
    convert_lightgbm,
    options=lgbm_options,
)

Convert again

# Convert the whole pipeline (scaler + LightGBM) to ONNX, targeting opset 12
# for the default domain and opset 2 for "ai.onnx.ml". The input is a float
# tensor with a dynamic batch dimension; its width is taken from the training
# data instead of being hard-coded, so it always matches the columns of X.
model_onnx = convert_sklearn(
    pipe,
    "pipeline_lightgbm",
    [("input", FloatTensorType([None, X.shape[1]]))],
    target_opset={"": 12, "ai.onnx.ml": 2},
)

# And save.
with open("pipeline_lightgbm.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())

Compare the predictions

Predictions with LightGBM.

# Reference output from the fitted scikit-learn pipeline: class labels for the
# first five samples and class probabilities for the first sample only.
lgbm_labels = pipe.predict(X[:5])
print("predict", lgbm_labels)
lgbm_probas = pipe.predict_proba(X[:1])
print("predict_proba", lgbm_probas)
predict [1 2 0 0 0]
predict_proba [[0.37745098 0.38299882 0.2395502 ]]

Predictions with onnxruntime.

# Load the saved ONNX model on the CPU provider and score the same five
# samples. The features are cast to float32 to match the FloatTensorType
# declared for the "input" at conversion time.
sess = rt.InferenceSession(
    "pipeline_lightgbm.onnx", providers=["CPUExecutionProvider"]
)

feeds = {"input": X[:5].astype(numpy.float32)}
pred_onx = sess.run(None, feeds)
# First output: predicted labels; second output: class probabilities.
onnx_labels, onnx_probas = pred_onx[0], pred_onx[1]
print("predict", onnx_labels)
print("predict_proba", onnx_probas[:1])
predict [1 2 0 0 0]
predict_proba [{0: 0.37745100259780884, 1: 0.38299882411956787, 2: 0.23955020308494568}]

Total running time of the script: (0 minutes 0.066 seconds)

Gallery generated by Sphinx-Gallery