.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_tutorial/plot_usparse_xgboost.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_tutorial_plot_usparse_xgboost.py:


.. _example-sparse-tfidf:

TfIdf and sparse matrices
=========================

.. index:: xgboost, lightgbm, sparse, ensemble

`TfidfVectorizer `_
usually creates sparse data. If the data is sparse enough, matrices
usually stay sparse all along the pipeline until the predictor is trained.
Sparse matrices do not store null values, so a cell which is not stored
may be read either as a zero or as a missing value. Because some
predictors make the difference, this ambiguity may introduce
discrepancies when the model is converted into ONNX.
This example looks into several configurations.

Imports, setups
+++++++++++++++

All imports. It also registers ONNX converters for :epkg:`xgboost`
and *lightgbm*.

.. GENERATED FROM PYTHON SOURCE LINES 26-69

.. code-block:: Python


    import warnings
    import numpy
    import pandas
    import onnxruntime as rt
    from tqdm import tqdm
    from sklearn.compose import ColumnTransformer
    from sklearn.datasets import load_iris
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.ensemble import RandomForestClassifier

    try:
        from sklearn.ensemble import HistGradientBoostingClassifier
    except ImportError:
        HistGradientBoostingClassifier = None
    from xgboost import XGBClassifier
    from lightgbm import LGBMClassifier
    from skl2onnx.common.data_types import FloatTensorType, StringTensorType
    from skl2onnx import to_onnx, update_registered_converter
    from skl2onnx.sklapi import CastTransformer, ReplaceTransformer
    from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes
    from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
    from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm

    # Register the converters so that skl2onnx knows how to convert
    # XGBClassifier and LGBMClassifier when they appear in a pipeline.
    update_registered_converter(
        XGBClassifier,
        "XGBoostXGBClassifier",
        calculate_linear_classifier_output_shapes,
        convert_xgboost,
        options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
    )
    update_registered_converter(
        LGBMClassifier,
        "LightGbmLGBMClassifier",
        calculate_linear_classifier_output_shapes,
        convert_lightgbm,
        options={"nocl": [True, False], "zipmap": [True, False]},
    )

.. GENERATED FROM PYTHON SOURCE LINES 70-74

Artificial datasets
+++++++++++++++++++

Iris + a text column.

.. GENERATED FROM PYTHON SOURCE LINES 74-92

.. code-block:: Python


    cst = ["class zero", "class one", "class two"]
    data = load_iris()
    X = data.data[:, :2]
    y = data.target

    # Shuffle the rows first so that the features, the text column
    # and the labels stay aligned once the dataframe is built.
    ind = numpy.arange(X.shape[0])
    numpy.random.shuffle(ind)
    X = X[ind, :].copy()
    y = y[ind].copy()

    df = pandas.DataFrame(X)
    df.columns = [f"c{c}" for c in df.columns]
    df["text"] = [cst[i] for i in y]

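
The effect of `sparse_threshold` can be checked on a few rows before any model is trained. The following is a minimal sketch added for illustration, not part of the original script: it assumes only scikit-learn, scipy and pandas, and the frame `small` and the helper `make_preprocessing` are hypothetical names. It applies the same preprocessing as the pipelines below with `sparse_threshold=1.` and `sparse_threshold=0.`: both outputs hold the same values, but only the dense one stores the zeros explicitly.

.. code-block:: Python

    import numpy
    import pandas
    from scipy.sparse import issparse
    from sklearn.compose import ColumnTransformer
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    # Hypothetical three-row frame with the same layout as `df`.
    small = pandas.DataFrame(
        {
            "c0": [5.1, 7.0, 6.3],
            "c1": [3.5, 3.2, 3.3],
            "text": ["class zero", "class one", "class two"],
        }
    )


    def make_preprocessing(sparse_threshold):
        # Hypothetical helper: same preprocessing as the pipelines trained below.
        return ColumnTransformer(
            [
                ("scale1", StandardScaler(), [0, 1]),
                (
                    "subject",
                    Pipeline([("count", CountVectorizer()), ("tfidf", TfidfTransformer())]),
                    "text",
                ),
            ],
            sparse_threshold=sparse_threshold,
        )


    sparse_out = make_preprocessing(1.0).fit_transform(small)  # stays sparse
    dense_out = make_preprocessing(0.0).fit_transform(small)  # always dense

    print(issparse(sparse_out), issparse(dense_out))  # True False
    # Same values, but the sparse matrix does not store its zeros.
    print(sparse_out.nnz, dense_out.size)
    print(numpy.allclose(sparse_out.toarray(), dense_out))
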

.. GENERATED FROM PYTHON SOURCE LINES 93-99

Train ensemble after sparse
+++++++++++++++++++++++++++

The example uses the Iris dataset with an artificial text column
preprocessed with tf-idf. `sparse_threshold=1.` prevents sparse matrices
from being converted into dense matrices.

.. GENERATED FROM PYTHON SOURCE LINES 99-233

.. code-block:: Python


    def make_pipelines(
        df_train,
        y_train,
        models=None,
        sparse_threshold=1.0,
        replace_nan=False,
        insert_replace=False,
    ):
        if models is None:
            models = [
                RandomForestClassifier,
                HistGradientBoostingClassifier,
                XGBClassifier,
                LGBMClassifier,
            ]
        models = [_ for _ in models if _ is not None]

        pipes = []
        for model in tqdm(models):
            if model == HistGradientBoostingClassifier:
                kwargs = dict(max_iter=5)
            elif model == XGBClassifier:
                kwargs = dict(n_estimators=5, use_label_encoder=False)
            else:
                kwargs = dict(n_estimators=5)

            if insert_replace:
                # Same pipeline with an explicit ReplaceTransformer after tf-idf.
                pipe = Pipeline(
                    [
                        (
                            "union",
                            ColumnTransformer(
                                [
                                    ("scale1", StandardScaler(), [0, 1]),
                                    (
                                        "subject",
                                        Pipeline(
                                            [
                                                ("count", CountVectorizer()),
                                                ("tfidf", TfidfTransformer()),
                                                ("repl", ReplaceTransformer()),
                                            ]
                                        ),
                                        "text",
                                    ),
                                ],
                                sparse_threshold=sparse_threshold,
                            ),
                        ),
                        ("cast", CastTransformer()),
                        ("cls", model(max_depth=3, **kwargs)),
                    ]
                )
            else:
                pipe = Pipeline(
                    [
                        (
                            "union",
                            ColumnTransformer(
                                [
                                    ("scale1", StandardScaler(), [0, 1]),
                                    (
                                        "subject",
                                        Pipeline(
                                            [
                                                ("count", CountVectorizer()),
                                                ("tfidf", TfidfTransformer()),
                                            ]
                                        ),
                                        "text",
                                    ),
                                ],
                                sparse_threshold=sparse_threshold,
                            ),
                        ),
                        ("cast", CastTransformer()),
                        ("cls", model(max_depth=3, **kwargs)),
                    ]
                )

            try:
                pipe.fit(df_train, y_train)
            except TypeError as e:
                # A model which requires dense data fails here on sparse inputs.
                obs = dict(model=model.__name__, pipe=pipe, error=e, model_onnx=None)
                pipes.append(obs)
                continue

            options = {model: {"zipmap": False}}
            if replace_nan:
                options[TfidfTransformer] = {"nan": True}

            # convert
            with warnings.catch_warnings(record=False):
                warnings.simplefilter("ignore", (FutureWarning, UserWarning))
                model_onnx = to_onnx(
                    pipe,
                    initial_types=[
                        ("input", FloatTensorType([None, 2])),
                        ("text", StringTensorType([None, 1])),
                    ],
                    target_opset={"": 12, "ai.onnx.ml": 2},
                    options=options,
                )

            # Save the last converted model for inspection.
            with open("model.onnx", "wb") as f:
                f.write(model_onnx.SerializeToString())

            # Compare the probabilities computed by onnxruntime and scikit-learn.
            sess = rt.InferenceSession(
                model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
            )
            inputs = {
                "input": df[["c0", "c1"]].values.astype(numpy.float32),
                "text": df[["text"]].values,
            }
            pred_onx = sess.run(None, inputs)

            diff = numpy.abs(pred_onx[1].ravel() - pipe.predict_proba(df).ravel()).sum()

            obs = dict(
                model=model.__name__, discrepencies=diff, model_onnx=model_onnx, pipe=pipe
            )
            pipes.append(obs)

        return pipes


    data_sparse = make_pipelines(df, y)
    stat = pandas.DataFrame(data_sparse).drop(["model_onnx", "pipe"], axis=1)
    if "error" in stat.columns:
        print(stat.drop("error", axis=1))
    stat


.. rst-class:: sphx-glr-script-out

.. code-block:: pytb

    Traceback (most recent call last):
      File "/home/xadupre/github/sklearn-onnx/docs/tutorial/plot_usparse_xgboost.py", line 227, in <module>
        data_sparse = make_pipelines(df, y)
                      ^^^^^^^^^^^^^^^^^^^^^
      File "/home/xadupre/github/sklearn-onnx/docs/tutorial/plot_usparse_xgboost.py", line 217, in make_pipelines
        diff = numpy.abs(pred_onx[1].ravel() - pipe.predict_proba(df).ravel()).sum()
                                               ^^^^^^^^^^^^^^^^^^^^^^
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/pipeline.py", line 896, in predict_proba
        with _raise_or_warn_if_not_fitted(self):
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/lib/python3.12/contextlib.py", line 144, in __exit__
        next(self.gen)
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/pipeline.py", line 60, in _raise_or_warn_if_not_fitted
        check_is_fitted(estimator)
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py", line 1756, in check_is_fitted
        if not _is_fitted(estimator, attributes, all_or_any):
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py", line 1665, in _is_fitted
        return estimator.__sklearn_is_fitted__()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/pipeline.py", line 1310, in __sklearn_is_fitted__
        check_is_fitted(last_step)
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/validation.py", line 1751, in check_is_fitted
        tags = get_tags(estimator)
               ^^^^^^^^^^^^^^^^^^^
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/utils/_tags.py", line 405, in get_tags
        sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator)  # type: ignore[attr-defined]
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/xadupre/vv/this312/lib/python3.12/site-packages/sklearn/base.py", line 540, in __sklearn_tags__
        tags = super().__sklearn_tags__()
               ^^^^^^^^^^^^^^^^^^^^^^^^
    AttributeError: 'super' object has no attribute '__sklearn_tags__'




.. GENERATED FROM PYTHON SOURCE LINES 234-240

Sparse data hurts: for some models, the predictions computed by
*onnxruntime* differ from the ones computed by scikit-learn.

Dense data
++++++++++

Let's replace sparse data with dense data by using `sparse_threshold=0.`

.. GENERATED FROM PYTHON SOURCE LINES 240-248

.. code-block:: Python


    data_dense = make_pipelines(df, y, sparse_threshold=0.0)
    stat = pandas.DataFrame(data_dense).drop(["model_onnx", "pipe"], axis=1)
    if "error" in stat.columns:
        print(stat.drop("error", axis=1))
    stat


.. GENERATED FROM PYTHON SOURCE LINES 249-251

This is much better. Let's compare how the preprocessing applies to the data.

.. GENERATED FROM PYTHON SOURCE LINES 251-258

.. code-block:: Python


    print("sparse")
    print(data_sparse[-1]["pipe"].steps[0][-1].transform(df)[:2])
    print()
    print("dense")
    print(data_dense[-1]["pipe"].steps[0][-1].transform(df)[:2])


.. GENERATED FROM PYTHON SOURCE LINES 259-278

This shows that `RandomForestClassifier `_ and
`XGBClassifier `_ do not process sparse and dense
matrices the same way, as opposed to `LGBMClassifier `_.
`HistGradientBoostingClassifier `_ fails on sparse data.

Dense data with nan
+++++++++++++++++++

Let's keep sparse data in the scikit-learn pipeline but replace null
values with nan in the ONNX graph.

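
The replacement itself is simple to picture. The sketch below is only an illustration, not the converter option used by this example: `zeros_to_nan` is a hypothetical helper wrapped in scikit-learn's `FunctionTransformer`. It densifies a small tf-idf-like sparse matrix and turns every zero into nan, so that a predictor which treats unstored sparse cells as missing would see missing values instead of zeros.

.. code-block:: Python

    import numpy
    from scipy.sparse import csr_matrix
    from sklearn.preprocessing import FunctionTransformer


    def zeros_to_nan(X):
        # Hypothetical helper: densify the input and mark every zero as missing.
        dense = X.toarray() if hasattr(X, "toarray") else numpy.asarray(X)
        dense = dense.astype(numpy.float64)  # astype returns a copy
        dense[dense == 0] = numpy.nan
        return dense


    replace = FunctionTransformer(zeros_to_nan)

    # A small sparse matrix standing in for a tf-idf output.
    tfidf_like = csr_matrix(numpy.array([[0.0, 0.5], [0.7, 0.0]]))
    out = replace.fit_transform(tfidf_like)
    print(out)

Densifying this way can be costly with a large vocabulary; the sketch only shows how the values are mapped.
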

.. GENERATED FROM PYTHON SOURCE LINES 278-286

.. code-block:: Python


    data_dense = make_pipelines(df, y, sparse_threshold=1.0, replace_nan=True)
    stat = pandas.DataFrame(data_dense).drop(["model_onnx", "pipe"], axis=1)
    if "error" in stat.columns:
        print(stat.drop("error", axis=1))
    stat


.. GENERATED FROM PYTHON SOURCE LINES 287-296

Dense, 0 replaced by nan
++++++++++++++++++++++++

Instead of using a specific option to replace null values with nan
values, a custom transformer called ReplaceTransformer is explicitly
inserted into the pipeline. A new converter is added to the list of
supported models. It is equivalent to the previous option, except
it is more explicit.

.. GENERATED FROM PYTHON SOURCE LINES 296-305

.. code-block:: Python


    data_dense = make_pipelines(
        df, y, sparse_threshold=1.0, replace_nan=False, insert_replace=True
    )
    stat = pandas.DataFrame(data_dense).drop(["model_onnx", "pipe"], axis=1)
    if "error" in stat.columns:
        print(stat.drop("error", axis=1))
    stat


.. GENERATED FROM PYTHON SOURCE LINES 306-312

Conclusion
++++++++++

*onnxruntime* does not support sparse data yet. Unless dense arrays
are used, the conversion needs to be tuned depending on the model
which follows the TfIdf preprocessing.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.258 seconds)


.. _sphx_glr_download_auto_tutorial_plot_usparse_xgboost.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_usparse_xgboost.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_usparse_xgboost.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_usparse_xgboost.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_