Probabilities as a vector or as a ZipMap

A classifier usually returns a matrix of probabilities. By default, sklearn-onnx appends a ZipMap operator that converts this matrix into a list of dictionaries, where each probability is mapped to its class id or name. This mechanism keeps the class names attached to the probabilities, but it increases the prediction time and is not always needed. Let's see how to deactivate this behaviour on the Iris example.
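As a rough illustration, the sketch below reproduces in plain numpy the transformation that ZipMap applies: each row of the probability matrix becomes a dictionary keyed by class label. The helper name zipmap_rows is hypothetical, and this is not how onnxruntime implements the operator.

```python
import numpy as np


def zipmap_rows(probabilities, classes):
    """Turn an (n_samples, n_classes) matrix into a list of dicts.

    Hypothetical helper mimicking the shape of ZipMap's output: one
    dictionary per observation, mapping each class label to its probability.
    """
    probabilities = np.asarray(probabilities, dtype=np.float32)
    if probabilities.shape[1] != len(classes):
        raise ValueError("one column per class is expected")
    return [
        {label: float(p) for label, p in zip(classes, row)}
        for row in probabilities
    ]


probs = np.array([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]], dtype=np.float32)
print(zipmap_rows(probs, [0, 1, 2]))
print(zipmap_rows(probs, ["setosa", "versicolor", "virginica"]))
```

Every observation gets its own dictionary, so the labels are stored once per row; this is the extra work that disappears when ZipMap is removed.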

Train a model and convert it

from timeit import repeat
import numpy
import sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import onnxruntime as rt
import onnx
import skl2onnx
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from sklearn.linear_model import LogisticRegression

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = LogisticRegression(max_iter=500)
clr.fit(X_train, y_train)
print(clr)

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
LogisticRegression(max_iter=500)

Output type

Let's use onnxruntime to confirm that the probabilities are returned as a list of dictionaries.

sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
res = sess.run(None, {"float_input": X_test.astype(numpy.float32)})
print(res[1][:2])
print("probabilities type:", type(res[1]))
print("type for the first observation:", type(res[1][0]))
[{0: 9.76420778897591e-05, 1: 0.046403080224990845, 2: 0.9534992575645447}, {0: 0.018943823873996735, 1: 0.8978846669197083, 2: 0.083171546459198}]
probabilities type: <class 'list'>
type for the first observation: <class 'dict'>
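If downstream code needs a matrix again, the list of dictionaries can be stacked back into an array. A minimal sketch, assuming every dictionary has the same keys; the helper zipmap_to_array is hypothetical and the rows below are made-up stand-ins for res[1].

```python
import numpy as np


def zipmap_to_array(rows, classes=None):
    """Stack a list of {label: probability} dicts into a 2D array.

    Columns follow `classes` if given, otherwise the key order
    of the first dictionary.
    """
    if classes is None:
        classes = list(rows[0].keys())
    return np.array(
        [[row[c] for c in classes] for row in rows], dtype=np.float32
    )


# Stand-in for res[1]: two observations, three integer classes.
rows = [{0: 0.1, 1: 0.2, 2: 0.7}, {0: 0.6, 1: 0.3, 2: 0.1}]
matrix = zipmap_to_array(rows)
print(matrix.shape)  # (2, 3)
```

Passing classes explicitly makes the column order independent of how the dictionaries happen to be ordered.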

Without ZipMap

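Let's remove the ZipMap operator with the converter option {"zipmap": False}. The probabilities are then returned as a numpy array instead of a list of dictionaries.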

initial_type = [("float_input", FloatTensorType([None, 4]))]
options = {id(clr): {"zipmap": False}}
onx2 = convert_sklearn(
    clr, initial_types=initial_type, options=options, target_opset=12
)

sess2 = rt.InferenceSession(
    onx2.SerializeToString(), providers=["CPUExecutionProvider"]
)
res2 = sess2.run(None, {"float_input": X_test.astype(numpy.float32)})
print(res2[1][:2])
print("probabilities type:", type(res2[1]))
print("type for the first observation:", type(res2[1][0]))
[[9.7642078e-05 4.6403080e-02 9.5349926e-01]
 [1.8943824e-02 8.9788467e-01 8.3171546e-02]]
probabilities type: <class 'numpy.ndarray'>
type for the first observation: <class 'numpy.ndarray'>
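Without ZipMap, the class labels no longer travel with the probabilities. The sketch below assumes that column j holds the probability of the j-th entry of clr.classes_, the same ordering scikit-learn uses in predict_proba, and recovers one label per row by taking the most probable column. The class names and numbers are made-up stand-ins for clr.classes_ and res2[1].

```python
import numpy as np

# Stand-ins for clr.classes_ and the array returned in res2[1].
classes = np.array(["setosa", "versicolor", "virginica"])
probabilities = np.array(
    [[0.05, 0.15, 0.80],
     [0.70, 0.20, 0.10]],
    dtype=np.float32,
)

# Column j is assumed to hold the probability of classes[j];
# the most probable column gives the predicted label.
labels = classes[np.argmax(probabilities, axis=1)]
print(labels)  # ['virginica' 'setosa']
```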

One output per class

The option {"zipmap": "columns"} removes the final ZipMap operator and splits the probabilities into columns. The resulting model produces one output for the label and one output per class.

options = {id(clr): {"zipmap": "columns"}}
onx3 = convert_sklearn(
    clr, initial_types=initial_type, options=options, target_opset=12
)

sess3 = rt.InferenceSession(
    onx3.SerializeToString(), providers=["CPUExecutionProvider"]
)
res3 = sess3.run(None, {"float_input": X_test.astype(numpy.float32)})
for i, out in enumerate(sess3.get_outputs()):
    print(
        "output: '{}' shape={} values={}...".format(
            out.name, res3[i].shape, res3[i][:2]
        )
    )
output: 'output_label' shape=(38,) values=[2 1]...
output: 'i0' shape=(38,) values=[9.7642078e-05 1.8943824e-02]...
output: 'i1' shape=(38,) values=[0.04640308 0.89788467]...
output: 'i2' shape=(38,) values=[0.95349926 0.08317155]...
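If a single matrix is needed after all, the per-class outputs (named i0, i1, i2 in this run) can be stacked column-wise. A sketch on made-up vectors standing in for res3[1], res3[2] and res3[3]:

```python
import numpy as np

# Stand-ins for the per-class outputs i0, i1, i2.
per_class = [
    np.array([0.1, 0.6], dtype=np.float32),  # i0
    np.array([0.2, 0.3], dtype=np.float32),  # i1
    np.array([0.7, 0.1], dtype=np.float32),  # i2
]

# column_stack turns k vectors of length n into an (n, k) matrix,
# the same layout as the output obtained with {"zipmap": False}.
matrix = np.column_stack(per_class)
print(matrix.shape)  # (2, 3)
```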

Let's compare prediction time

X32 = X_test.astype(numpy.float32)

print("Time with ZipMap:")
print(repeat(lambda: sess.run(None, {"float_input": X32}), number=100, repeat=10))

print("Time without ZipMap:")
print(repeat(lambda: sess2.run(None, {"float_input": X32}), number=100, repeat=10))

print("Time without ZipMap but with columns:")
print(repeat(lambda: sess3.run(None, {"float_input": X32}), number=100, repeat=10))

# On this run, the prediction is roughly twice as fast
# without ZipMap; splitting the probabilities into
# columns falls in between.
# The gain should be even larger when the classes are
# labelled with strings rather than integers, because
# the final result (a list of dictionaries) repeats the
# same keys for every observation and onnxruntime may
# copy that information many times.
Time with ZipMap:
[0.004339300000083313, 0.003741600000012113, 0.0030674999998154817, 0.002657600000020466, 0.002626199999667733, 0.0029488000000128523, 0.0027304000000185624, 0.0026580999997349863, 0.003430199999911565, 0.0028902999997626466]
Time without ZipMap:
[0.001563100000112172, 0.002031699999861303, 0.0012540000002445595, 0.0012501999999585678, 0.0015628999999535154, 0.0018725999998423504, 0.0012913999999000225, 0.0015812999999980093, 0.0016777000000729458, 0.0020423000000846514]
Time without ZipMap but with columns:
[0.0032931999999163963, 0.002602000000024418, 0.002499299999726645, 0.002414399999906891, 0.002675799999906303, 0.004039400000237947, 0.0024936999998317333, 0.0022391000002244255, 0.00329169999986334, 0.003487199999653967]
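Each list above holds 10 totals, each covering 100 calls. The timeit documentation recommends looking at the minimum rather than the mean, since the lowest value is the least affected by other activity on the machine. A minimal sketch of that summary; per_call_seconds is a hypothetical helper, and the lambda is a placeholder for sess.run(None, {"float_input": X32}).

```python
from timeit import repeat


def per_call_seconds(fn, number=100, repeat_count=10):
    """Return the best per-call time of fn, in seconds.

    timeit.repeat returns one total per repetition, each covering
    `number` calls; the minimum is the least noisy estimate.
    """
    totals = repeat(fn, number=number, repeat=repeat_count)
    return min(totals) / number


# Placeholder workload; swap in one of the sess.run lambdas above.
elapsed = per_call_seconds(lambda: sum(range(1000)))
print(f"{elapsed * 1e6:.1f} µs per call")
```

Dividing by number makes the three variants directly comparable as latencies per prediction call on the whole test set.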

Versions used for this example

print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx:  1.15.0
onnxruntime:  1.16.0+cu118
skl2onnx:  1.16.0

Total running time of the script: (0 minutes 0.159 seconds)
