TfIdf and sparse matrices

TfidfVectorizer usually creates sparse data. If the data is sparse enough, the matrices usually stay sparse all along the pipeline until the predictor is trained. Sparse matrices do not store null values, so a zero and a missing value look the same. Because some predictors treat them differently, this ambiguity may introduce discrepancies when the pipeline is converted into ONNX. This example looks into several configurations.
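The ambiguity is easy to see with a tiny scipy example (a standalone sketch, independent of the pipelines below): once a matrix is stored in sparse format, an explicit zero and an absent value become indistinguishable.

import numpy
from scipy import sparse

# a dense row where the second feature is an explicit zero
dense = numpy.array([[1.0, 0.0, 3.0]])
# the sparse version does not store the zero at all
sp = sparse.csr_matrix(dense)

print(dense)         # [[1. 0. 3.]]
print(sp.nnz)        # 2 stored elements: the explicit zero is gone
print(sp.toarray())  # densifying produces 0 again, not a missing value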

Imports, setups

All imports. This cell also registers the ONNX converters for xgboost and lightgbm.

import warnings
import numpy
import pandas
import onnxruntime as rt
from tqdm import tqdm
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.ensemble import RandomForestClassifier

try:
    from sklearn.ensemble import HistGradientBoostingClassifier
except ImportError:
    HistGradientBoostingClassifier = None
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from skl2onnx.common.data_types import FloatTensorType, StringTensorType
from skl2onnx import to_onnx, update_registered_converter
from skl2onnx.sklapi import CastTransformer, ReplaceTransformer
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes
from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm


update_registered_converter(
    XGBClassifier,
    "XGBoostXGBClassifier",
    calculate_linear_classifier_output_shapes,
    convert_xgboost,
    options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)
update_registered_converter(
    LGBMClassifier,
    "LightGbmLGBMClassifier",
    calculate_linear_classifier_output_shapes,
    convert_lightgbm,
    options={"nocl": [True, False], "zipmap": [True, False]},
)

Artificial datasets

Iris + a text column.

cst = ["class zero", "class one", "class two"]

data = load_iris()
X = data.data[:, :2]
y = data.target

df = pandas.DataFrame(X)
df.columns = [f"c{c}" for c in df.columns]
df["text"] = [cst[i] for i in y]


ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

Train ensemble after sparse

The example uses the Iris dataset with an artificial text column preprocessed with a tf-idf. sparse_threshold=1.0 prevents the sparse matrices from being converted into dense matrices.
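As a reminder of what this parameter does, the short sketch below (assuming the df built above; not part of the benchmark) shows that ColumnTransformer returns a sparse matrix when the overall density of the stacked outputs is below sparse_threshold, and a dense array otherwise.

from scipy import sparse
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler

for th in [1.0, 0.0]:
    ct = ColumnTransformer(
        [
            ("scale", StandardScaler(), ["c0", "c1"]),
            ("tfidf", TfidfVectorizer(), "text"),
        ],
        sparse_threshold=th,
    )
    out = ct.fit_transform(df)
    # sparse_threshold=1.0 keeps the stacked output sparse,
    # sparse_threshold=0.0 forces a dense array
    print(th, type(out).__name__, sparse.issparse(out))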

def make_pipelines(
    df_train,
    y_train,
    models=None,
    sparse_threshold=1.0,
    replace_nan=False,
    insert_replace=False,
):
    if models is None:
        models = [
            RandomForestClassifier,
            HistGradientBoostingClassifier,
            XGBClassifier,
            LGBMClassifier,
        ]
    models = [_ for _ in models if _ is not None]

    pipes = []
    for model in tqdm(models):
        if model == HistGradientBoostingClassifier:
            kwargs = dict(max_iter=5)
        elif model == XGBClassifier:
            kwargs = dict(n_estimators=5, use_label_encoder=False)
        else:
            kwargs = dict(n_estimators=5)

        if insert_replace:
            pipe = Pipeline(
                [
                    (
                        "union",
                        ColumnTransformer(
                            [
                                ("scale1", StandardScaler(), [0, 1]),
                                (
                                    "subject",
                                    Pipeline(
                                        [
                                            ("count", CountVectorizer()),
                                            ("tfidf", TfidfTransformer()),
                                            ("repl", ReplaceTransformer()),
                                        ]
                                    ),
                                    "text",
                                ),
                            ],
                            sparse_threshold=sparse_threshold,
                        ),
                    ),
                    ("cast", CastTransformer()),
                    ("cls", model(max_depth=3, **kwargs)),
                ]
            )
        else:
            pipe = Pipeline(
                [
                    (
                        "union",
                        ColumnTransformer(
                            [
                                ("scale1", StandardScaler(), [0, 1]),
                                (
                                    "subject",
                                    Pipeline(
                                        [
                                            ("count", CountVectorizer()),
                                            ("tfidf", TfidfTransformer()),
                                        ]
                                    ),
                                    "text",
                                ),
                            ],
                            sparse_threshold=sparse_threshold,
                        ),
                    ),
                    ("cast", CastTransformer()),
                    ("cls", model(max_depth=3, **kwargs)),
                ]
            )

        try:
            pipe.fit(df_train, y_train)
        except TypeError as e:
            obs = dict(model=model.__name__, pipe=pipe, error=e, model_onnx=None)
            pipes.append(obs)
            continue

        # zipmap=False makes the classifier output a plain probability matrix.
        # The optional 'nan' flag asks the TfidfTransformer converter to
        # produce NaN instead of 0 in the ONNX graph.
        options = {model: {"zipmap": False}}
        if replace_nan:
            options[TfidfTransformer] = {"nan": True}

        # convert
        with warnings.catch_warnings(record=False):
            warnings.simplefilter("ignore", (FutureWarning, UserWarning))
            model_onnx = to_onnx(
                pipe,
                initial_types=[
                    ("input", FloatTensorType([None, 2])),
                    ("text", StringTensorType([None, 1])),
                ],
                target_opset={"": 12, "ai.onnx.ml": 2},
                options=options,
            )

        with open("model.onnx", "wb") as f:
            f.write(model_onnx.SerializeToString())

        sess = rt.InferenceSession(
            model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        inputs = {
            "input": df[["c0", "c1"]].values.astype(numpy.float32),
            "text": df[["text"]].values,
        }
        pred_onx = sess.run(None, inputs)

        diff = numpy.abs(pred_onx[1].ravel() - pipe.predict_proba(df).ravel()).sum()

        obs = dict(
            model=model.__name__, discrepencies=diff, model_onnx=model_onnx, pipe=pipe
        )
        pipes.append(obs)

    return pipes


data_sparse = make_pipelines(df, y)
stat = pandas.DataFrame(data_sparse).drop(["model_onnx", "pipe"], axis=1)
if "error" in stat.columns:
    print(stat.drop("error", axis=1))
stat
  0%|          | 0/4 [00:00<?, ?it/s]/home/xadupre/vv/this312/lib/python3.12/site-packages/xgboost/training.py:183: UserWarning: [09:45:35] WARNING: /workspace/src/learner.cc:738:
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)

 75%|███████▌  | 3/4 [00:00<00:00,  3.18it/s][LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000728 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 53
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 5
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py:2735: UserWarning: X does not have valid feature names, but LGBMClassifier was fitted with feature names
  warnings.warn(

100%|██████████| 4/4 [00:01<00:00,  3.84it/s]
                            model  discrepencies
0          RandomForestClassifier       0.947052
1  HistGradientBoostingClassifier            NaN
2                   XGBClassifier      15.196619
3                  LGBMClassifier       0.000009
                            model  discrepencies                                              error
0          RandomForestClassifier       0.947052                                                NaN
1  HistGradientBoostingClassifier            NaN  Sparse data was passed for X, but dense data i...
2                   XGBClassifier      15.196619                                                NaN
3                  LGBMClassifier       0.000009                                                NaN


Sparse data hurts.

Dense data

Let’s replace sparse data with dense by using sparse_threshold=0.

data_dense = make_pipelines(df, y, sparse_threshold=0.0)
stat = pandas.DataFrame(data_dense).drop(["model_onnx", "pipe"], axis=1)
if "error" in stat.columns:
    print(stat.drop("error", axis=1))
stat
  0%|          | 0/4 [00:00<?, ?it/s]
 50%|█████     | 2/4 [00:00<00:00,  2.78it/s]/home/xadupre/vv/this312/lib/python3.12/site-packages/xgboost/training.py:183: UserWarning: [09:45:37] WARNING: /workspace/src/learner.cc:738:
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)

 75%|███████▌  | 3/4 [00:01<00:00,  1.81it/s][LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.033149 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 53
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 5
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py:2735: UserWarning: X does not have valid feature names, but LGBMClassifier was fitted with feature names
  warnings.warn(

100%|██████████| 4/4 [00:01<00:00,  2.58it/s]
100%|██████████| 4/4 [00:01<00:00,  2.42it/s]
                            model  discrepencies
0          RandomForestClassifier       0.940913
1  HistGradientBoostingClassifier       0.000005
2                   XGBClassifier       2.899390
3                  LGBMClassifier       0.000009


This is much better. Let’s compare how the preprocessing applies to the data.

print("sparse")
print(data_sparse[-1]["pipe"].steps[0][-1].transform(df)[:2])
print()
print("dense")
print(data_dense[-1]["pipe"].steps[0][-1].transform(df)[:2])
sparse
<Compressed Sparse Row sparse matrix of dtype 'float64'
        with 8 stored elements and shape (2, 6)>
  Coords        Values
  (0, 0)        -0.9006811702978088
  (0, 1)        1.019004351971607
  (0, 2)        0.4323732931220851
  (0, 5)        0.9016947018779491
  (1, 0)        -1.1430169111851105
  (1, 1)        -0.13197947932162468
  (1, 2)        0.4323732931220851
  (1, 5)        0.9016947018779491

dense
[[-0.90068117  1.01900435  0.43237329  0.          0.          0.9016947 ]
 [-1.14301691 -0.13197948  0.43237329  0.          0.          0.9016947 ]]

This shows that RandomForestClassifier and XGBClassifier do not process sparse and dense matrices the same way, unlike LGBMClassifier. And HistGradientBoostingClassifier fails on sparse input.
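One way to verify this directly is to feed each fitted classifier the same preprocessed features twice, once sparse and once densified, and compare the probabilities. The sketch below reuses the pipelines stored in data_sparse; it is an extra check, not part of the original benchmark.

import numpy
from scipy import sparse

for obs in data_sparse:
    if obs.get("model_onnx") is None:
        # HistGradientBoostingClassifier could not be trained on sparse input
        continue
    pipe = obs["pipe"]
    feats = pipe[:-1].transform(df)      # preprocessing steps only
    cls = pipe.steps[-1][-1]             # the fitted classifier
    dense_feats = feats.toarray() if sparse.issparse(feats) else feats
    diff = numpy.abs(
        cls.predict_proba(feats) - cls.predict_proba(dense_feats)
    ).sum()
    print(obs["model"], diff)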

Dense data with nan

Let’s keep sparse data in the scikit-learn pipeline but replace null (zero) values by nan in the ONNX graph.
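Inside make_pipelines, replace_nan=True simply adds a converter option for TfidfTransformer. Written in isolation, the call looks roughly like this (a sketch assuming pipe is one of the fitted pipelines above, with RandomForestClassifier as the final estimator):

options = {
    RandomForestClassifier: {"zipmap": False},  # plain probability matrix
    TfidfTransformer: {"nan": True},            # output NaN instead of 0
}
model_onnx = to_onnx(
    pipe,
    initial_types=[
        ("input", FloatTensorType([None, 2])),
        ("text", StringTensorType([None, 1])),
    ],
    target_opset={"": 12, "ai.onnx.ml": 2},
    options=options,
)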

data_dense = make_pipelines(df, y, sparse_threshold=1.0, replace_nan=True)
stat = pandas.DataFrame(data_dense).drop(["model_onnx", "pipe"], axis=1)
if "error" in stat.columns:
    print(stat.drop("error", axis=1))
stat
  0%|          | 0/4 [00:00<?, ?it/s]/home/xadupre/vv/this312/lib/python3.12/site-packages/xgboost/training.py:183: UserWarning: [09:45:38] WARNING: /workspace/src/learner.cc:738:
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)

 75%|███████▌  | 3/4 [00:00<00:00,  3.21it/s][LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.012050 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 53
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 5
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py:2735: UserWarning: X does not have valid feature names, but LGBMClassifier was fitted with feature names
  warnings.warn(

100%|██████████| 4/4 [00:01<00:00,  4.06it/s]
100%|██████████| 4/4 [00:01<00:00,  3.83it/s]
                            model  discrepencies
0          RandomForestClassifier      24.634892
1  HistGradientBoostingClassifier            NaN
2                   XGBClassifier       2.899390
3                  LGBMClassifier       0.000009
                            model  discrepencies                                              error
0          RandomForestClassifier      24.634892                                                NaN
1  HistGradientBoostingClassifier            NaN  Sparse data was passed for X, but dense data i...
2                   XGBClassifier       2.899390                                                NaN
3                  LGBMClassifier       0.000009                                                NaN


Dense, 0 replaced by nan

Instead of using a converter option to replace null values by nan values, a custom transformer called ReplaceTransformer is explicitly inserted into the pipeline. A new converter is added to the list of supported models. It is equivalent to the previous option, just more explicit.
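Conceptually, the extra step turns the stored zeros of the tf-idf output into NaN before the ensemble sees them. A rough numpy equivalent of that replacement (an illustration only, not the actual ReplaceTransformer implementation) looks like this:

import numpy

tfidf_out = numpy.array([[0.43237329, 0.0, 0.0, 0.9016947]], dtype=numpy.float32)
replaced = tfidf_out.copy()
replaced[replaced == 0] = numpy.nan  # absent terms become NaN rather than 0
print(replaced)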

data_dense = make_pipelines(
    df, y, sparse_threshold=1.0, replace_nan=False, insert_replace=True
)
stat = pandas.DataFrame(data_dense).drop(["model_onnx", "pipe"], axis=1)
if "error" in stat.columns:
    print(stat.drop("error", axis=1))
stat
  0%|          | 0/4 [00:00<?, ?it/s]/home/xadupre/vv/this312/lib/python3.12/site-packages/xgboost/training.py:183: UserWarning: [09:45:39] WARNING: /workspace/src/learner.cc:738:
Parameters: { "use_label_encoder" } are not used.

  bst.update(dtrain, iteration=i, fobj=obj)

 75%|███████▌  | 3/4 [00:01<00:00,  2.67it/s][LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000045 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 53
[LightGBM] [Info] Number of data points in the train set: 150, number of used features: 5
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Info] Start training from score -1.098612
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py:2735: UserWarning: X does not have valid feature names, but LGBMClassifier was fitted with feature names
  warnings.warn(

100%|██████████| 4/4 [00:01<00:00,  3.30it/s]
                            model  discrepencies
0          RandomForestClassifier      41.288296
1  HistGradientBoostingClassifier            NaN
2                   XGBClassifier       2.899390
3                  LGBMClassifier       0.000009
                            model  discrepencies                                              error
0          RandomForestClassifier      41.288296                                                NaN
1  HistGradientBoostingClassifier            NaN  Sparse data was passed for X, but dense data i...
2                   XGBClassifier       2.899390                                                NaN
3                  LGBMClassifier       0.000009                                                NaN


Conclusion

Unless dense arrays are used, the conversion needs to be tuned depending on the model which follows the TfIdf preprocessing, because onnxruntime does not support sparse tensors yet.

Total running time of the script: (0 minutes 4.999 seconds)
