ONNX Runtime Backend for ONNX
ONNX Runtime implements the ONNX backend API, so predictions can be computed with this runtime through the same interface as any other ONNX backend. Let’s use the API to compute the predictions of a simple logistic regression model.
import skl2onnx
import onnxruntime
import onnx
import sklearn
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
import numpy
from onnxruntime import get_device
import numpy as np
import onnxruntime.backend as backend
Let’s create an ONNX graph first.
data = load_iris()
X, Y = data.data, data.target
logreg = LogisticRegression(C=1e5).fit(X, Y)
model = skl2onnx.to_onnx(logreg, X.astype(np.float32))
name = "logreg_iris.onnx"
with open(name, "wb") as f:
    f.write(model.SerializeToString())
Let’s use the ONNX backend API to test it.
model = onnx.load(name)
rep = backend.prepare(model)
x = np.array(
    [[-1.0, -2.0, 5.0, 6.0], [-1.0, -2.0, -3.0, -4.0], [-1.0, -2.0, 7.0, 8.0]],
    dtype=np.float32,
)
label, proba = rep.run(x)
print("label={}".format(label))
print("probabilities={}".format(proba))
label=[2 0 2]
probabilities=[{0: 0.0, 1: 0.0, 2: 1.0}, {0: 1.0, 1: 1.9515885113950192e-38, 2: 0.0}, {0: 0.0, 1: 0.0, 2: 1.0}]
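As the output above shows, the probabilities come back as a list of dictionaries, one per input row, each mapping a class label to its probability. The following is a minimal sketch, assuming that structure holds, of how such a result could be turned into dense rows and checked against the returned labels. The helper name `proba_to_matrix` is hypothetical and not part of onnxruntime; the sample values are copied from the output above.

```python
def proba_to_matrix(proba, classes=None):
    """Convert a list of {class: probability} dicts into dense rows.

    Hypothetical helper, not part of onnxruntime. Classes missing from a
    row are treated as having probability 0.0.
    """
    if classes is None:
        # Collect every class label seen in any row, in sorted order.
        classes = sorted({c for row in proba for c in row})
    matrix = [[float(row.get(c, 0.0)) for c in classes] for row in proba]
    return classes, matrix


# Sample values copied from the output printed above.
proba = [
    {0: 0.0, 1: 0.0, 2: 1.0},
    {0: 1.0, 1: 1.9515885113950192e-38, 2: 0.0},
    {0: 0.0, 1: 0.0, 2: 1.0},
]

classes, matrix = proba_to_matrix(proba)

# Most probable class per row; this should agree with the returned labels.
predicted = [classes[max(range(len(r)), key=r.__getitem__)] for r in matrix]

print(classes)    # [0, 1, 2]
print(predicted)  # [2, 0, 2]
```

The dense form is convenient when the probabilities need to be compared with another runtime or passed to code that expects one column per class.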
The device reported by get_device() depends on how the installed onnxruntime package was built: GPU or CPU.
print(get_device())
GPU
The backend can also load the model directly from a file, without going through onnx.load.
rep = backend.prepare(name)
x = np.array(
    [[-1.0, -2.0, -3.0, -4.0], [-1.0, -2.0, -3.0, -4.0], [-1.0, -2.0, -3.0, -4.0]],
    dtype=np.float32,
)
label, proba = rep.run(x)
print("label={}".format(label))
print("probabilities={}".format(proba))
label=[0 0 0]
probabilities=[{0: 1.0, 1: 1.9515885113950192e-38, 2: 0.0}, {0: 1.0, 1: 1.9515885113950192e-38, 2: 0.0}, {0: 1.0, 1: 1.9515885113950192e-38, 2: 0.0}]
The backend API is also implemented by other frameworks, which makes it easier to switch between runtimes without changing the calling code.
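To illustrate what a shared API buys, here is a rough sketch of a helper that takes the backend as a parameter and only relies on the `supports_device`, `prepare` and `run` calls defined by the ONNX backend interface. The helper `predict_with_backend` and the `EchoBackend` / `EchoRep` stand-ins are hypothetical, introduced only so the sketch runs without any runtime installed; they are not part of onnx or onnxruntime.

```python
def predict_with_backend(backend, model, inputs, device="CPU"):
    """Run `model` on `inputs` through any object exposing the ONNX backend API.

    Hypothetical helper: it assumes `backend` provides `supports_device`
    and `prepare`, and that the prepared object provides `run`.
    """
    if not backend.supports_device(device):
        raise RuntimeError(f"backend does not support device {device!r}")
    rep = backend.prepare(model, device=device)
    return rep.run(inputs)


class EchoRep:
    """Stand-in for a prepared model; it simply returns its inputs."""

    def __init__(self, model):
        self.model = model

    def run(self, inputs, **kwargs):
        return [inputs]


class EchoBackend:
    """Stand-in backend (an assumption for illustration, not a real runtime)."""

    @classmethod
    def supports_device(cls, device):
        return device == "CPU"

    @classmethod
    def prepare(cls, model, device="CPU", **kwargs):
        return EchoRep(model)


outputs = predict_with_backend(EchoBackend, "logreg_iris.onnx", [[1.0, 2.0]])
print(outputs)  # [[[1.0, 2.0]]]
```

With a real runtime, the same helper would presumably be called with the backend module itself (for instance `onnxruntime.backend`) in place of `EchoBackend`, provided that module exposes those functions at module level.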
Versions used for this example
print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", onnxruntime.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 2.2.0
scikit-learn: 1.6.0
onnx: 1.18.0
onnxruntime: 1.21.0+cu126
skl2onnx: 1.18.0
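When some of these packages may be missing, the versions can also be read from the installed distribution metadata without importing the packages. This is only a sketch using the standard library; the helper name `installed_versions` is hypothetical, and the distribution names are assumptions (a GPU build of ONNX Runtime, for example, may be installed under a different distribution name than `onnxruntime`).

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed under this name.
            found[name] = None
    return found


# Distribution names are assumptions; adjust them to the local installation.
versions = installed_versions(
    ["numpy", "scikit-learn", "onnx", "onnxruntime", "skl2onnx"]
)
for name, ver in versions.items():
    print(f"{name}: {ver or 'not installed'}")
```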