Intermediate results and investigation
There are many reasons why a user may want more than just the converted ONNX model. Intermediate results may be needed, such as the output of every node in the graph. The ONNX graph may need to be altered to remove some nodes: transfer learning usually removes the last layers of a deep neural network. Another reason is debugging: it often happens that the runtime fails to compute the predictions due to a shape mismatch, and it is then useful to get the shape of every intermediate result. This example looks into two ways of doing it.
Look into pipeline steps
The first way is a tricky one: it overloads the methods transform, predict and predict_proba to keep a copy of their inputs and outputs, as sketched just below. It then goes through every step of the pipeline. If the pipeline has n steps, it converts the pipeline with step 1, then the pipeline with steps 1 and 2, then with steps 1, 2, 3…
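The overloading can be pictured with a simplified, hypothetical helper (the real implementation is collect_intermediate_steps from skl2onnx.helpers, used below): it wraps transform so that every call records its input and output.

def attach_debug_cache(step):
    # hypothetical sketch: record every input/output of transform
    original_transform = step.transform
    cache = {"inputs": [], "outputs": []}

    def transform(data, *args, **kwargs):
        result = original_transform(data, *args, **kwargs)
        cache["inputs"].append(data)
        cache["outputs"].append(result)
        return result

    step.transform = transform
    return cache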
import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType
The pipeline.
data = load_iris()
X = data.data
pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
pipe.fit(X)
The function goes through every step, overloads the transform methods and returns an ONNX graph for every step.
steps = collect_intermediate_steps(
    pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)
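The truncated conversions can also be reproduced by hand. This is only a sketch, assuming the converter accepts a pipeline rebuilt from already fitted steps; collect_intermediate_steps additionally wires the debug cache described above.

truncated_models = [
    to_onnx(
        # pipeline restricted to its first i+1 fitted steps
        Pipeline(steps=pipe.steps[: i + 1]),
        X[:1].astype(numpy.float32),
        target_opset=17,
    )
    for i in range(len(pipe.steps))
]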
We call the method transform to populate the cache the overloaded transform methods keep.
pipe.transform(X)
array([[4.00404832, 0.21295824, 3.15861505],
[4.05055769, 0.99604549, 2.72563625],
[4.22040251, 0.65198444, 3.02188403],
[4.22860026, 0.9034561 , 2.93043986],
[4.12353003, 0.40215457, 3.33653691],
[3.89643029, 1.21154793, 3.52936423],
[4.2374443 , 0.50244932, 3.19234391],
[3.99197553, 0.09132468, 3.03242342],
[4.4445734 , 1.42174651, 2.9795537 ],
[4.08705397, 0.78993078, 2.84221713],
[3.92610748, 0.78999385, 3.3507236 ],
[4.09865843, 0.27618123, 3.09168785],
[4.19718995, 1.03497888, 2.85719428],
[4.66454355, 1.33482453, 3.26547013],
[4.13826195, 1.63865558, 3.90871872],
[4.47633229, 2.39898792, 4.51414747],
[4.02762963, 1.20748818, 3.63475229],
[3.92839122, 0.21618828, 3.09288714],
[3.72388908, 1.20986655, 3.36736664],
[4.10521298, 0.86706182, 3.53103908],
[3.67990695, 0.50401564, 2.8436663 ],
[3.95222508, 0.66826437, 3.30977167],
[4.52523323, 0.68658071, 3.63034505],
[3.60185594, 0.47945627, 2.59973228],
[4.00845791, 0.36345425, 3.00678098],
[3.91688379, 0.99023912, 2.60615351],
[3.80966594, 0.22683089, 2.86756462],
[3.90931811, 0.2947186 , 3.0958103 ],
[3.89828815, 0.25361098, 2.99275191],
[4.12581898, 0.65019824, 2.92503544],
[4.04810077, 0.80138328, 2.78137328],
[3.58928575, 0.52309257, 2.76837135],
[4.49874494, 1.57658655, 4.14390673],
[4.43563509, 1.87652483, 4.2489329 ],
[4.008642 , 0.76858489, 2.76272437],
[4.04525625, 0.54896332, 2.9103173 ],
[3.81211172, 0.63079314, 3.09001251],
[4.26421417, 0.45982568, 3.44057865],
[4.45456872, 1.2336976 , 3.06034971],
[3.92683189, 0.14580827, 2.99422852],
[4.02712265, 0.20261743, 3.16142101],
[4.69480008, 2.67055552, 2.9575648 ],
[4.4496996 , 0.90927099, 3.20355969],
[3.71964918, 0.50081008, 2.89721622],
[3.91143692, 0.92159916, 3.37471011],
[4.04740147, 1.01946042, 2.70316642],
[4.14683513, 0.86953764, 3.56280964],
[4.26327469, 0.72275914, 3.04646993],
[3.98021229, 0.72324305, 3.37186092],
[3.99446269, 0.30295342, 2.94518173],
[0.9452659 , 3.43619989, 1.8639233 ],
[1.00829443, 2.97232682, 1.38933168],
[0.73653572, 3.51850037, 1.6428166 ],
[2.76204203, 3.33264308, 1.00264343],
[1.16604995, 3.35747592, 0.86560047],
[1.86711784, 2.77550662, 0.3750882 ],
[1.00955989, 3.01808184, 1.56489146],
[3.3697155 , 2.77360088, 1.55619573],
[1.18358725, 3.21148368, 1.08067281],
[2.48285941, 2.66294828, 0.82637993],
[3.79967007, 3.62389817, 2.01281316],
[1.50054672, 2.70011145, 0.76353654],
[2.80438695, 3.53658932, 1.27727048],
[1.34023352, 2.98813829, 0.62868121],
[2.09655735, 2.32311723, 0.77087912],
[1.00633966, 3.14311522, 1.43272989],
[1.71909321, 2.68234835, 0.80192 ],
[2.17926627, 2.63954211, 0.60569829],
[2.40214871, 3.97369206, 1.18764767],
[2.52511757, 2.87494798, 0.727372 ],
[1.21113562, 3.03853641, 1.31653995],
[1.68291281, 2.8022861 , 0.52313867],
[1.71597913, 3.68305664, 0.75211692],
[1.59856561, 2.96833851, 0.55292557],
[1.33753092, 2.9760862 , 0.87815407],
[1.06462905, 3.13002382, 1.19061026],
[1.13996294, 3.56679427, 1.22441299],
[0.5652633 , 3.5903606 , 1.37258261],
[1.39763754, 2.93839428, 0.56006248],
[2.49518379, 2.58203512, 0.81289907],
[2.75025306, 2.99796537, 0.94324481],
[2.82866407, 2.92597852, 1.03283946],
[2.08201734, 2.68907313, 0.4343386 ],
[1.48418961, 3.42215998, 0.48873673],
[1.92943813, 2.62771445, 0.91606802],
[1.40011111, 2.75915071, 1.69140864],
[0.79992473, 3.30075052, 1.44311693],
[2.2708714 , 3.73017167, 1.05036852],
[1.91690629, 2.37943811, 0.83618809],
[2.47017911, 2.98789866, 0.6470029 ],
[2.32571939, 2.89079656, 0.53979211],
[1.29304411, 2.86642713, 0.81855214],
[2.17526444, 2.86642575, 0.43194777],
[3.40973541, 2.96966239, 1.58383257],
[2.10849001, 2.77003779, 0.3618706 ],
[1.87076527, 2.38255534, 0.83187956],
[1.85116384, 2.55559903, 0.58147273],
[1.44451588, 2.8455521 , 0.70529895],
[3.11774537, 2.56987887, 1.34329146],
[1.94990512, 2.64007308, 0.41481694],
[1.04248866, 4.24274589, 2.26819164],
[1.57935402, 3.57067982, 0.72581017],
[0.52274684, 4.44150237, 2.09231844],
[0.83298461, 3.69480186, 1.12321156],
[0.5678145 , 4.11613683, 1.68255837],
[1.1830756 , 5.03326801, 2.72592116],
[2.8024351 , 3.3503222 , 1.25267619],
[0.93117407, 4.577021 , 2.18852343],
[1.46246781, 4.363498 , 1.45283591],
[1.4207266 , 4.79334275, 3.18264007],
[0.47962495, 3.62749566, 1.67405555],
[1.09881086, 3.89360823, 1.04698204],
[0.31830999, 4.1132966 , 1.75049044],
[1.98175664, 3.82688169, 0.92293569],
[1.54698303, 3.91538879, 1.35721732],
[0.68407345, 3.89835633, 1.86138575],
[0.52205472, 3.70128288, 1.34561415],
[2.03678461, 5.18341242, 3.80620352],
[1.84250874, 5.58136629, 2.90217633],
[2.43634558, 4.02615768, 1.16636059],
[0.48150581, 4.31907679, 2.22297775],
[1.67578773, 3.4288432 , 0.88685031],
[1.47096547, 5.19031307, 2.72431414],
[1.22329554, 3.64273089, 0.79101156],
[0.47109224, 4.00723617, 2.10999425],
[0.62558995, 4.2637671 , 2.28591141],
[1.14490402, 3.45930032, 0.74392898],
[0.99645552, 3.27575645, 0.98053107],
[0.90181942, 4.05342943, 1.3282425 ],
[0.76242411, 4.1585729 , 1.98849304],
[1.08628479, 4.71100584, 2.22822113],
[2.10967488, 5.12224641, 3.84302072],
[0.93357383, 4.13401784, 1.41836425],
[1.17526973, 3.39830644, 0.74517066],
[1.66051938, 3.63719075, 0.76558228],
[1.23742547, 5.08776655, 2.80545775],
[1.04697429, 4.00416552, 2.26945032],
[0.55013293, 3.58815834, 1.42313566],
[1.12188023, 3.19454679, 0.93290167],
[0.20983625, 4.09907253, 1.92136662],
[0.5691276 , 4.28416057, 2.02737038],
[0.49810802, 4.17402084, 2.01513279],
[1.57935402, 3.57067982, 0.72581017],
[0.50497262, 4.32128686, 2.19577242],
[0.81423561, 4.3480018 , 2.37699732],
[0.55018391, 4.1240495 , 1.77340222],
[1.58648502, 3.97564407, 0.98294137],
[0.49931367, 3.7539635 , 1.39731191],
[1.06536484, 3.7969924 , 2.13822884],
[1.18287527, 3.25638099, 0.96885287]])
We compute every step and compare ONNX and scikit-learn outputs.
for step in steps:
    print("----------------------------")
    print(step["model"])
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]
    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)
# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# then compare them to the outputs produced by truncated
# ONNX graphs built from the original pipeline.
----------------------------
StandardScaler()
difference 4.799262827148709e-07
----------------------------
KMeans(n_clusters=3, n_init=3)
difference 1.095537650763756e-06
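The differences above come from the float64 to float32 conversion. If the comparison should fail loudly rather than only print, it can be written as an assertion, as in this sketch:

from numpy.testing import assert_allclose

for step in steps:
    sess = InferenceSession(
        step["onnx_step"].SerializeToString(),
        providers=["CPUExecutionProvider"],
    )
    got = sess.run(None, {"X": X.astype(numpy.float32)})[-1]
    expected = step["model"]._debug.outputs["transform"]
    # tolerances account for the float64 -> float32 conversion
    assert_allclose(expected, got, rtol=1e-3, atol=1e-5)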
Python runtime to look into every node
The Python runtime may be useful to easily look into every node of the ONNX graph. This option can be used to check when the computation fails due to NaN values or a dimension mismatch.
onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)
oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[4.0040483 , 0.21295893, 3.158615 ],
[4.050557 , 0.99604493, 2.7256362 ]], dtype=float32)]
And, with a higher verbosity, to get a sense of the intermediate results.
oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
+C Ad_Addcst: float32:(3,) in [1.0065230131149292, 5.035177230834961]
+C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1674340963363647]
+C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
+I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
+ variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
+ Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
+ Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
+ Ge_Y0: float32:(2, 3) in [-10.366023063659668, 8.10552978515625]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
+ Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.956035614013672]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
+ Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.407014846801758]
ArgMin(Ad_C0) -> label
+ label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
+ scores: float32:(2, 3) in [0.2129589319229126, 4.0505571365356445]
[array([1, 1]), array([[4.0040483 , 0.21295893, 3.158615 ],
[4.050557 , 0.99604493, 2.7256362 ]], dtype=float32)]
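When one intermediate result deserves closer inspection, the sub-graph that produces it can be extracted and run on its own. A sketch with onnx.utils.Extractor, using the intermediate name variable (the Scaler output seen in the trace above) and assuming the extractor can infer the tensor's type:

from onnx.utils import Extractor

# keep only the part of the graph from input X to the Scaler output
sub_model = Extractor(onx).extract_model(["X"], ["variable"])
sub_sess = InferenceSession(
    sub_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
print(sub_sess.run(None, {"X": X[:2].astype(numpy.float32)}))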