Note
Go to the end to download the full example code
Different ways to convert a model
This example uses helper code that makes custom converters easy to implement.
Predict with onnxruntime
A simple function to check that the converted model works.
import onnxruntime
import onnx
import numpy
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans
from sklearn.pipeline import make_pipeline
from onnxruntime import InferenceSession
from skl2onnx import convert_sklearn, to_onnx, wrap_as_onnx_mixin
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.algebra.onnx_ops import OnnxSub, OnnxDiv
from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin
def predict_with_onnxruntime(onx, X):
    """Run an ONNX model with onnxruntime on ``X`` and return its first output.

    Parameters
    ----------
    onx
        ONNX model; anything exposing ``SerializeToString()``.
    X
        Input array; it is cast to ``np.float32`` before being fed.

    Returns
    -------
    The first output returned by ``InferenceSession.run``.
    """
    serialized = onx.SerializeToString()
    session = InferenceSession(serialized, providers=["CPUExecutionProvider"])
    # Only the first declared input is fed; a model with several inputs
    # would not be fully supplied here.
    first_input = session.get_inputs()[0]
    feeds = {first_input.name: X.astype(np.float32)}
    outputs = session.run(None, feeds)
    return outputs[0]
Simple KMeans
The first way: convert_sklearn().
[1 1 1 1 1 0 0 0 0 0]
The second way: to_onnx(): there is no need to declare the input type with FloatTensorType anymore.
[0 0 0 0 0 1 1 1 1 1]
The third way: wrap_as_onnx_mixin(): wraps the trained model into a new class inheriting from OnnxOperatorMixin.
[0 0 0 0 0 1 1 1 1 1]
The fourth way: wrap_as_onnx_mixin() can also be called before fitting the model.
[1 1 1 1 1 0 0 0 0 0]
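Pipeline and a custom object
This is a simple scaler: it subtracts the per-column mean and divides by the per-column standard deviation learned during fit.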
class CustomOpTransformer(BaseEstimator, TransformerMixin, OnnxOperatorMixin):
    """Standard scaler computing ``(X - mean) / std`` that can convert itself to ONNX.

    ``fit`` stores the per-column mean in ``W_`` and the per-column standard
    deviation in ``S_``; ``transform`` applies ``(X - W_) / S_``. The
    ``onnx_shape_calculator`` and ``to_onnx_operator`` hooks let the fitted
    transformer be expressed as an ONNX ``Sub`` node followed by a ``Div`` node.
    """

    def __init__(self):
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        # Opset used by to_onnx_operator when no target_opset is requested.
        self.op_version = 12

    def fit(self, X, y=None):
        """Learn column-wise mean (``W_``) and std (``S_``) of ``X``; ``y`` is ignored."""
        self.W_ = np.mean(X, axis=0)
        self.S_ = np.std(X, axis=0)
        return self

    def transform(self, X):
        """Return ``(X - W_) / S_``.

        NOTE(review): a constant column has ``S_ == 0`` and yields inf/nan.
        """
        return (X - self.W_) / self.S_

    def onnx_shape_calculator(self):
        """Return a shape calculator that gives the output the input's type."""

        def shape_calculator(operator):
            operator.outputs[0].type = operator.inputs[0].type

        return shape_calculator

    def to_onnx_operator(
        self, inputs=None, outputs=("Y",), target_opset=None, **kwargs
    ):
        """Build the ONNX expression ``Div(Sub(X, W_), S_)`` for the fitted scaler.

        Parameters
        ----------
        inputs
            Input names; must not be None. The first one is used as ``X``.
        outputs
            Output names given to the final ``Div`` node.
        target_opset
            Opset for every node; falls back to ``self.op_version`` when falsy.

        Raises
        ------
        RuntimeError
            If ``inputs`` is None.
        """
        if inputs is None:
            raise RuntimeError("Parameter inputs should contain at least one name.")
        opv = target_opset or self.op_version
        # get_inputs is presumably provided by OnnxOperatorMixin.
        i0 = self.get_inputs(inputs, 0)
        # ONNX Sub/Div require both operands to share a type; the model is
        # fed float32 data (see predict_with_onnxruntime), so cast constants.
        W = self.W_.astype(np.float32)
        S = self.S_.astype(np.float32)
        # Both nodes use the same opset; Sub was previously pinned to 12,
        # which produced a mixed-opset graph for any other target_opset.
        centered = OnnxSub(i0, W, op_version=opv)
        return OnnxDiv(centered, S, output_names=outputs, op_version=opv)
Way 1
[0 0 0 0 0 1 1 1 1 1]
Way 2
[1 1 1 1 1 0 0 0 0 0]
Way 3
[1 1 1 1 1 0 0 0 0 0]
Way 4
[0 0 0 0 0 1 1 1 1 1]
Display the ONNX graph
Finally, let's display the graph produced by sklearn-onnx.
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer  # noqa

# Build a pydot rendering of the ONNX graph of ``onx`` (the model converted
# earlier in the example) and save it as a Graphviz .dot file.
pydot_graph = GetPydotGraph(
    onx.graph,
    name=onx.graph.name,
    rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"
    ),
)
pydot_graph.write_dot("pipeline_onnx_mixin.dot")

import os  # noqa
import subprocess  # noqa

# Convert the .dot file to PNG with Graphviz; ``-O`` derives the output name
# (``pipeline_onnx_mixin.dot.png``, read below). ``check=True`` fails here if
# Graphviz is missing or errors, instead of a confusing missing-file error in
# ``plt.imread``; passing a list avoids invoking a shell.
subprocess.run(
    ["dot", "-O", "-Gdpi=300", "-Tpng", "pipeline_onnx_mixin.dot"], check=True
)

import matplotlib.pyplot as plt  # noqa

image = plt.imread("pipeline_onnx_mixin.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis("off")
(-0.5, 3103.5, 6900.5, -0.5)
Versions used for this example
import sklearn  # noqa
import skl2onnx  # noqa

# Report the library versions this example was run with, one per line.
print(
    "\n".join(
        f"{label} {module.__version__}"
        for label, module in (
            ("numpy:", numpy),
            ("scikit-learn:", sklearn),
            ("onnx: ", onnx),
            ("onnxruntime: ", onnxruntime),
            ("skl2onnx: ", skl2onnx),
        )
    )
)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx: 1.15.0
onnxruntime: 1.16.0+cu118
skl2onnx: 1.16.0
Total running time of the script: (0 minutes 3.140 seconds)