.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_cast_transformer.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_cast_transformer.py:


.. _l-cast_transformer:

Discrepancies with StandardScaler
=================================

A `StandardScaler `_ does a very basic scaling: it subtracts the mean
and divides by the standard deviation. The conversion into ONNX assumes
that ``x / y`` is equivalent to ``x * (1 / y)``, but that is not true
with float or double (see
`Will the compiler optimize division into multiplication `_).
Even if the difference is small, it may introduce discrepancies if the
next step is a decision tree: one small difference, and the decision
follows another path in the tree. Let's see how to solve that issue.

An example that fails
+++++++++++++++++++++

This is not a typical example: it is built to fail, relying on the fact
that ``x / y`` can differ from ``x * (1 / y)`` on a computer.

.. GENERATED FROM PYTHON SOURCE LINES 31-48

.. code-block:: Python

    import onnxruntime
    import onnx
    import os
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    from onnxruntime import InferenceSession
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.tree import DecisionTreeRegressor
    from skl2onnx.sklapi import CastTransformer
    from skl2onnx import to_onnx

.. GENERATED FROM PYTHON SOURCE LINES 49-50

The weird data: each feature is multiplied by ``pi * 2**i`` and
truncated to an integer, so the columns have very different magnitudes.

.. GENERATED FROM PYTHON SOURCE LINES 50-61

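Before building it, the premise can be checked with NumPy alone. The
following is a minimal sketch, independent of the example. The value
ranges and the seed are arbitrary assumptions. It counts how many float32
quotients change when the division is replaced by a multiplication with a
precomputed reciprocal.

```python
import numpy as np

# Arbitrary float32 numerators and denominators (assumed ranges and seed).
rng = np.random.default_rng(0)
x = rng.uniform(-1000, 1000, size=100_000).astype(np.float32)
y = rng.uniform(0.1, 100, size=100_000).astype(np.float32)

# One rounding: a true float32 division.
div = x / y
# Two roundings: compute the reciprocal first, then multiply.
mul = x * (np.float32(1) / y)

# Compare the two formulas element by element.
mismatch = div != mul
print("mismatches:", int(mismatch.sum()), "out of", x.size)
print("largest gap:", float(np.abs(div - mul).max()))
```

The exact counts depend on the seed. What matters is that they are not
zero. A tree whose thresholds were learned on one formula can therefore
receive slightly different values from the other.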
.. code-block:: Python

    X, y = make_regression(10000, 10, random_state=3)
    X_train, X_test, y_train, _ = train_test_split(X, y, random_state=3)
    Xi_train, yi_train = X_train.copy(), y_train.copy()
    Xi_test = X_test.copy()
    for i in range(X.shape[1]):
        Xi_train[:, i] = (Xi_train[:, i] * math.pi * 2**i).astype(np.int64)
        Xi_test[:, i] = (Xi_test[:, i] * math.pi * 2**i).astype(np.int64)
    max_depth = 10
    Xi_test = Xi_test.astype(np.float32)

.. GENERATED FROM PYTHON SOURCE LINES 62-63

A simple model: a StandardScaler followed by a decision tree.

.. GENERATED FROM PYTHON SOURCE LINES 63-70

.. code-block:: Python

    model1 = Pipeline(
        [("scaler", StandardScaler()), ("dt", DecisionTreeRegressor(max_depth=max_depth))]
    )
    model1.fit(Xi_train, yi_train)
    exp1 = model1.predict(Xi_test)

.. GENERATED FROM PYTHON SOURCE LINES 71-72

Conversion into ONNX.

.. GENERATED FROM PYTHON SOURCE LINES 72-75

.. code-block:: Python

    onx1 = to_onnx(model1, X_train[:1].astype(np.float32), target_opset=15)
    sess1 = InferenceSession(onx1.SerializeToString(), providers=["CPUExecutionProvider"])

.. GENERATED FROM PYTHON SOURCE LINES 76-77

The maximum difference between the scikit-learn and ONNX Runtime predictions.

.. GENERATED FROM PYTHON SOURCE LINES 77-88

.. code-block:: Python

    got1 = sess1.run(None, {"X": Xi_test})[0]


    def maxdiff(a1, a2):
        # Largest absolute difference between two prediction arrays.
        d = np.abs(a1.ravel() - a2.ravel())
        return d.max()


    md1 = maxdiff(exp1, got1)
    print(md1)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    322.39065126389346

.. GENERATED FROM PYTHON SOURCE LINES 89-90

The graph of the converted model.

.. GENERATED FROM PYTHON SOURCE LINES 90-108

.. code-block:: Python

    pydot_graph = GetPydotGraph(
        onx1.graph,
        name=onx1.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"
        ),
    )
    pydot_graph.write_dot("cast1.dot")

    os.system("dot -O -Gdpi=300 -Tpng cast1.dot")

    image = plt.imread("cast1.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis("off")

.. image-sg:: /auto_examples/images/sphx_glr_plot_cast_transformer_001.png
   :alt: plot cast transformer
   :srcset: /auto_examples/images/sphx_glr_plot_cast_transformer_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (-0.5, 2536.5, 1707.5, -0.5)

.. GENERATED FROM PYTHON SOURCE LINES 109-123

New pipeline
++++++++++++

Fixing the conversion requires replacing ``x * (1 / y)`` with ``x / y``,
and this division must happen in double. By default, *sklearn-onnx*
assumes every computation happens in float. The
`ONNX 1.7 specifications `_ do not support double scaling
(inputs and outputs can be double, but not the parameters). The solution
therefore changes the conversion, removing the Scaler node by setting the
``div`` option of the StandardScaler converter to ``'div_cast'``. It also
computes in double by inserting explicit casts.

.. GENERATED FROM PYTHON SOURCE LINES 123-149

.. code-block:: Python

    model2 = Pipeline(
        [
            ("cast64", CastTransformer(dtype=np.float64)),  # scale in double
            ("scaler", StandardScaler()),
            ("cast", CastTransformer()),  # back to float32 before the tree
            ("dt", DecisionTreeRegressor(max_depth=max_depth)),
        ]
    )
    model2.fit(Xi_train, yi_train)
    exp2 = model2.predict(Xi_test)

    onx2 = to_onnx(
        model2,
        X_train[:1].astype(np.float32),
        # replace the Scaler node by an explicit division done in double
        options={StandardScaler: {"div": "div_cast"}},
        target_opset=15,
    )

    sess2 = InferenceSession(onx2.SerializeToString(), providers=["CPUExecutionProvider"])
    got2 = sess2.run(None, {"X": Xi_test})[0]
    md2 = maxdiff(exp2, got2)

    print(md2)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2.9884569016758178e-05

.. GENERATED FROM PYTHON SOURCE LINES 150-151

The graph of the new model.

.. GENERATED FROM PYTHON SOURCE LINES 151-169

.. code-block:: Python

    pydot_graph = GetPydotGraph(
        onx2.graph,
        name=onx2.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"
        ),
    )
    pydot_graph.write_dot("cast2.dot")

    os.system("dot -O -Gdpi=300 -Tpng cast2.dot")

    image = plt.imread("cast2.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis("off")

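Independently of the rendering above, the idea behind the fix can be
mimicked with NumPy alone. This is a rough analogue of the two
computations, not the ONNX graph itself. The feature values, the seed and
the float-only formula (float32 parameters with a precomputed reciprocal)
are assumptions chosen for illustration. The reference is what ``model2``
feeds its tree: a cast to float64, a true division, and a cast back to
float32.

```python
import numpy as np

# Hypothetical single feature with large integer values, loosely mimicking
# the weird data above. The ONNX input is float32.
rng = np.random.default_rng(1)
x32 = np.round(rng.normal(0, 5000, size=100_000)).astype(np.float32)
x64 = x32.astype(np.float64)

# StandardScaler statistics, fitted in double precision.
mean, scale = x64.mean(), x64.std()

# Reference: cast to float64, true division, cast back to float32.
expected = ((x64 - mean) / scale).astype(np.float32)

# Float-only path (assumed): float32 parameters and a reciprocal.
float_only = (x32 - np.float32(mean)) * np.float32(1.0 / scale)

# Double path: the same operations as the reference, which is the point
# of casting to double before dividing.
double_path = ((x64 - mean) / scale).astype(np.float32)

print("float-only mismatches:", int((float_only != expected).sum()))
print("double-path mismatches:", int((double_path != expected).sum()))
```

The double path matches the reference by construction. Casting before
dividing aims to repeat, at inference time, the operations scikit-learn
performed during training. The float-only path disagrees on part of the
samples, and each disagreement can send a sample down another branch.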
.. image-sg:: /auto_examples/images/sphx_glr_plot_cast_transformer_002.png
   :alt: plot cast transformer
   :srcset: /auto_examples/images/sphx_glr_plot_cast_transformer_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (-0.5, 2536.5, 4171.5, -0.5)

.. GENERATED FROM PYTHON SOURCE LINES 170-171

**Versions used for this example**

.. GENERATED FROM PYTHON SOURCE LINES 171-181

.. code-block:: Python

    import sklearn  # noqa

    print("numpy:", np.__version__)
    print("scikit-learn:", sklearn.__version__)
    import skl2onnx  # noqa

    print("onnx: ", onnx.__version__)
    print("onnxruntime: ", onnxruntime.__version__)
    print("skl2onnx: ", skl2onnx.__version__)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    numpy: 1.26.4
    scikit-learn: 1.6.dev0
    onnx:  1.17.0
    onnxruntime:  1.18.0+cu118
    skl2onnx:  1.17.0


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.086 seconds)


.. _sphx_glr_download_auto_examples_plot_cast_transformer.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_cast_transformer.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_cast_transformer.py `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_