Convert a pipeline with a LightGBM classifier
sklearn-onnx only converts scikit-learn models into ONNX, but many libraries implement the scikit-learn API so that their models can be included in a scikit-learn pipeline. This example considers a pipeline including a LightGBM model. sklearn-onnx can convert the whole pipeline as long as it knows the converter associated with LGBMClassifier. Let’s see how to do it.
Train a LightGBM classifier
import onnxruntime as rt
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_classifier_output_shapes,
)  # noqa
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)  # noqa
from skl2onnx.common.data_types import FloatTensorType
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from lightgbm import LGBMClassifier
data = load_iris()
X = data.data[:, :2]
y = data.target
ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()
pipe = Pipeline(
    [("scaler", StandardScaler()), ("lgbm", LGBMClassifier(n_estimators=3))]
)
pipe.fit(X, y)
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000695 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 47
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 2
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
Register the converter for LGBMClassifier
The converter is implemented in onnxmltools (onnxmltools…LightGbm.py), and the shape calculator as well (onnxmltools…Classifier.py).
update_registered_converter(
    LGBMClassifier,
    "LightGbmLGBMClassifier",
    calculate_linear_classifier_output_shapes,
    convert_lightgbm,
    options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)
Convert again
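Once the converter is registered, the whole pipeline can be converted with convert_sklearn and saved to pipeline_lightgbm.onnx, the file loaded in the next section. The sketch below is one way to do it; the target_opset values are an assumption and should be adjusted to the installed onnxruntime.

model_onnx = convert_sklearn(
    pipe,
    "pipeline_lightgbm",
    [("input", FloatTensorType([None, 2]))],
    # assumed opset versions; pick values supported by your runtime
    target_opset={"": 12, "ai.onnx.ml": 2},
)

# serialize the ONNX graph to disk
with open("pipeline_lightgbm.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())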
Compare the predictions
Predictions with LightGbm.
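A minimal sketch of the calls that produce the output below, using the fitted pipeline directly (labels for the first five samples, probabilities for the first one):

print("predict", pipe.predict(X[:5]))
print("predict_proba", pipe.predict_proba(X[:1]))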
predict [1 2 1 0 1]
predict_proba [[0.25335584 0.45934348 0.28730068]]
Predictions with onnxruntime.
sess = rt.InferenceSession(
    "pipeline_lightgbm.onnx", providers=["CPUExecutionProvider"]
)
pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])
predict [1 2 1 0 1]
predict_proba [{0: 0.25335583090782166, 1: 0.45934349298477173, 2: 0.287300705909729}]
Total running time of the script: (0 minutes 0.060 seconds)