Convert a pipeline with a LightGBM regressor

The discrepancies observed when using float with the TreeEnsemble operator (see Issues when switching to float) explain why the converter for LGBMRegressor may introduce significant discrepancies even when it is used with float tensors.

Library lightgbm is implemented with doubles. A random forest regressor with multiple trees computes its prediction by adding the prediction of every tree. After being converted into ONNX, this summation becomes $\left[\sum\right]_{i=1}^F float(T_i(x))$, where $F$ is the number of trees in the forest, $T_i(x)$ the output of tree $i$ and $\left[\sum\right]$ a float addition. The discrepancy can be expressed as

$$D(x) = \left|\left[\sum\right]_{i=1}^F float(T_i(x)) - \sum_{i=1}^F T_i(x)\right|.$$

It grows with the number of trees in the forest.
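
The effect is easy to reproduce with numpy alone. The snippet below (a minimal sketch, independent of the model trained later) sums a thousand random values once with a float32 accumulator and once in double precision:

import numpy

rng = numpy.random.RandomState(0)
values = rng.randn(1000)  # stands for the per-tree outputs T_i(x)

ref = values.sum()  # double-precision sum, what lightgbm computes

acc = numpy.float32(0.0)  # float additions, what TreeEnsembleRegressor does
for v in values.astype(numpy.float32):
    acc += v

print("double sum:", ref)
print("float sum :", float(acc))
print("difference:", abs(float(acc) - ref))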

To reduce the impact, an option was added to split the TreeEnsembleRegressor node into multiple ones and to do the summation in double precision this time. If we assume the node is split into $a$ nodes, the discrepancy then becomes

$$D'(x) = \left|\sum_{k=1}^a \left[\sum\right]_{i=1}^{F/a} float(T_{(k-1)F/a+i}(x)) - \sum_{i=1}^F T_i(x)\right|.$$
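
Continuing the sketch above, the same thousand values can be summed in chunks of 100 float additions, with the partial sums combined in double precision. The chunk size 100 is only for illustration:

import numpy

rng = numpy.random.RandomState(0)
values = rng.randn(1000)  # per-tree outputs T_i(x)
ref = values.sum()        # double-precision reference

split = 100
partials = []
for k in range(0, len(values), split):
    acc = numpy.float32(0.0)  # float additions inside one chunk
    for v in values[k:k + split].astype(numpy.float32):
        acc += v
    partials.append(float(acc))  # chunk result promoted to double

approx = sum(partials)  # double-precision sum of the partial sums
print("difference with split:", abs(approx - ref))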

Train an LGBMRegressor

import packaging.version as pv
import warnings
import timeit
import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from tqdm import tqdm
from lightgbm import LGBMRegressor
from onnxruntime import InferenceSession
from skl2onnx import to_onnx, update_registered_converter
from skl2onnx.common.shape_calculator import (
    calculate_linear_regressor_output_shapes,
)
from onnxmltools import __version__ as oml_version
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
    convert_lightgbm,
)


# A random regression problem: 1000 samples, 20 features.
N = 1000
X = numpy.random.randn(N, 20)
y = numpy.random.randn(N) + numpy.random.randn(N) * 100 * numpy.random.randint(
    0, 1, 1000
)

# Many trees make the summation error easier to observe.
reg = LGBMRegressor(n_estimators=1000)
reg.fit(X, y)
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000475 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 5100
[LightGBM] [Info] Number of data points in the train set: 1000, number of used features: 20
[LightGBM] [Info] Start training from score 0.033823
LGBMRegressor(n_estimators=1000)


Register the converter for LGBMRegressor

The converter is implemented in onnxmltools: onnxmltools…LightGbm.py, and the shape calculator in onnxmltools…Regressor.py.

def skl2onnx_convert_lightgbm(scope, operator, container):
    # Forward the 'split' option (number of trees per TreeEnsembleRegressor
    # node) to the onnxmltools converter.
    options = scope.get_options(operator.raw_operator)
    if "split" in options:
        if pv.Version(oml_version) < pv.Version("1.9.2"):
            warnings.warn(
                "Option split was released in version 1.9.2 but %s is "
                "installed. It will be ignored." % oml_version,
                stacklevel=0,
            )
        operator.split = options["split"]
    else:
        operator.split = None
    convert_lightgbm(scope, operator, container)


update_registered_converter(
    LGBMRegressor,
    "LightGbmLGBMRegressor",
    calculate_linear_regressor_output_shapes,
    skl2onnx_convert_lightgbm,
    options={"split": None},
)

Convert

We convert the same model under two scenarios: a single TreeEnsembleRegressor node, or several. The split parameter sets the number of trees per TreeEnsembleRegressor node.

model_onnx = to_onnx(
    reg, X[:1].astype(numpy.float32), target_opset={"": 14, "ai.onnx.ml": 2}
)
model_onnx_split = to_onnx(
    reg,
    X[:1].astype(numpy.float32),
    target_opset={"": 14, "ai.onnx.ml": 2},
    options={"split": 100},
)

Discrepancies

sess = InferenceSession(
    model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
sess_split = InferenceSession(
    model_onnx_split.SerializeToString(), providers=["CPUExecutionProvider"]
)

X32 = X.astype(numpy.float32)
expected = reg.predict(X32)
got = sess.run(None, {"X": X32})[0].ravel()
got_split = sess_split.run(None, {"X": X32})[0].ravel()

disp = numpy.abs(got - expected).sum()
disp_split = numpy.abs(got_split - expected).sum()

print("sum of discrepancies 1 node", disp)
print("sum of discrepancies split node", disp_split, "ratio:", disp / disp_split)
sum of discrepancies 1 node 0.00010992635655524084
sum of discrepancies split node 4.182445361007939e-05 ratio: 2.6282795605666776

The sum of the discrepancies is reduced by a factor of about 2.6. The maximum discrepancy is much smaller too.

disc = numpy.abs(got - expected).max()
disc_split = numpy.abs(got_split - expected).max()

print("max discrepancies 1 node", disc)
print("max discrepancies split node", disc_split, "ratio:", disc / disc_split)
max discrepancies 1 node 9.335552881850617e-07
max discrepancies split node 2.9892391495423e-07 ratio: 3.1230531967574557

Processing time

The split model is slower, but not by much.

print(
    "processing time no split",
    timeit.timeit(lambda: sess.run(None, {"X": X32})[0], number=150),
)
print(
    "processing time split",
    timeit.timeit(lambda: sess_split.run(None, {"X": X32})[0], number=150),
)
processing time no split 1.0977208300027996
processing time split 1.2429911670005822

Split influence

Let’s see how the maximum discrepancy moves against the parameter split.

res = []
for i in tqdm([*range(20, 170, 20), 200, 300, 400, 500]):
    model_onnx_split = to_onnx(
        reg,
        X[:1].astype(numpy.float32),
        target_opset={"": 14, "ai.onnx.ml": 2},
        options={"split": i},
    )
    sess_split = InferenceSession(
        model_onnx_split.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    got_split = sess_split.run(None, {"X": X32})[0].ravel()
    disc_split = numpy.abs(got_split - expected).max()
    res.append(dict(split=i, disc=disc_split))

df = DataFrame(res).set_index("split")
df["baseline"] = disc
print(df)
100%|██████████| 12/12 [00:14<00:00,  1.23s/it]
               disc      baseline
split
20     1.955193e-07  9.335553e-07
40     3.277593e-07  9.335553e-07
60     3.374452e-07  9.335553e-07
80     3.948104e-07  9.335553e-07
100    2.989239e-07  9.335553e-07
120    2.703531e-07  9.335553e-07
140    3.906515e-07  9.335553e-07
160    3.629678e-07  9.335553e-07
200    4.123835e-07  9.335553e-07
300    6.290701e-07  9.335553e-07
400    5.661779e-07  9.335553e-07
500    6.868547e-07  9.335553e-07

Graph.

_, ax = plt.subplots(1, 1)
df.plot(
    title="Maximum discrepancy against split\nsplit = number of trees per node",
    ax=ax,
)

# plt.show()
[Figure: maximum discrepancy against split (split = number of trees per node)]

Total running time of the script: (0 minutes 19.930 seconds)
