Dataframe as an input#

A pipeline usually ingests data as a matrix. Data can only be converted into a single matrix if all the values share the same type. Data held in a dataframe, however, usually has multiple types: float, integer, or string for categories. ONNX also supports that case.

A dataset with categories#

import numpy
import pprint
from onnxruntime import InferenceSession
from pandas import DataFrame
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx
from skl2onnx.algebra.type_helper import guess_initial_types


# Toy dataset: two categorical string columns, two numerical columns and
# a binary target ``y``. Each list holds one column; row order is preserved.
data = DataFrame(
    {
        "CAT1": ["a", "b", "a", "a", "a", "a"],
        "CAT2": ["c", "d", "d", "d", "c", "c"],
        "num1": [0.5, 0.4, 0.5, 0.55, 0.35, 0.5],
        "num2": [0.6, 0.8, 0.56, 0.56, 0.86, 0.68],
        "y": [0, 1, 0, 1, 0, 1],
    }
)

# Columns to one-hot encode; all other feature columns pass through unchanged.
cat_cols = ["CAT1", "CAT2"]
# Feature matrix: every column except the target ``y``.
train_data = data.drop(columns="y")

# Dense one-hot encoding; categories unseen during fit are ignored.
categorical_transformer = Pipeline(
    steps=[("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False))]
)
preprocessor = ColumnTransformer(
    transformers=[("cat", categorical_transformer, cat_cols)],
    remainder="passthrough",
)
pipe = Pipeline(
    steps=[("preprocess", preprocessor), ("rf", RandomForestClassifier())]
)
pipe.fit(train_data, data["y"])
Pipeline(steps=[('preprocess',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('cat',
                                                  Pipeline(steps=[('onehot',
                                                                   OneHotEncoder(handle_unknown='ignore',
                                                                                 sparse_output=False))]),
                                                  ['CAT1', 'CAT2'])])),
                ('rf', RandomForestClassifier())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Conversion to ONNX#

Function to_onnx does not handle dataframes.

# Convert the fitted pipeline, passing a one-row sample of the dataframe
# (presumably only used to infer the input types). ``zipmap=False`` disables
# the ZipMap post-processing step for the random forest's probabilities.
onx = to_onnx(
    pipe,
    train_data[:1],
    options={RandomForestClassifier: {"zipmap": False}},
)

Prediction with ONNX#

onnxruntime does not support dataframes.

sess = InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
# Feeding the dataframe directly is expected to fail: print the error
# message instead of stopping the script.
try:
    sess.run(None, train_data)
except Exception as err:
    print(err)

# Unhide conversion logic with a dataframe
# ++++++++++++++++++++++++++++++++++++++++
#
# A dataframe can be seen as a set of columns with
# different types. That is what ONNX should see:
# a list of inputs, one per column, where the input name
# is the column name and the input type is the column type.


def guess_schema_from_data(X):
    """Return skl2onnx initial types for *X*, merged into one input if possible.

    ``guess_initial_types`` describes a dataframe as one input per column.
    When every column has the same tensor type and a 2-D shape ``[None, k]``
    (unknown batch size, known width ``k``), the columns are merged into a
    single input named ``"X"`` whose width is the sum of the column widths.
    Otherwise the per-column description is returned unchanged.

    Parameters
    ----------
    X : data accepted by ``guess_initial_types`` (e.g. a pandas DataFrame).

    Returns
    -------
    list of ``(name, tensor_type)`` tuples.
    """
    init = guess_initial_types(X)
    # Nothing to merge; also avoids indexing into an empty result below.
    if not init:
        return init
    merged_type = init[0][1].__class__
    total_width = 0
    for _, col in init:
        shape = col.shape
        # Only [None, k] inputs can be concatenated along the feature axis.
        if len(shape) != 2 or shape[0] is not None:
            return init
        # Mixed element types (e.g. strings and doubles) cannot share a tensor.
        if col.__class__ is not merged_type:
            return init
        # An unknown width cannot be summed into a fixed merged width.
        if shape[1] is None:
            return init
        total_width += shape[1]
    return [("X", merged_type([None, total_width]))]


# Guess the ONNX input schema from the dataframe and display it.
init = guess_schema_from_data(train_data)
pprint.pprint(init)
run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f029b60e570>, ['label', 'probabilities'],   CAT1 CAT2  num1  num2
0    a    c  0.50  0.60
1    b    d  0.40  0.80
2    a    d  0.50  0.56
3    a    d  0.55  0.56
4    a    c  0.35  0.86
5    a    c  0.50  0.68, None
[('CAT1', StringTensorType(shape=[None, 1])),
 ('CAT2', StringTensorType(shape=[None, 1])),
 ('num1', DoubleTensorType(shape=[None, 1])),
 ('num2', DoubleTensorType(shape=[None, 1]))]

Let’s use float (32-bit) instead of double for the numerical columns.

# Cast every non-categorical column to float32 so the guessed schema uses
# FloatTensorType instead of DoubleTensorType for them.
train_data = train_data.astype(
    {col: numpy.float32 for col in train_data.columns if col not in cat_cols}
)


# Re-guess the schema now that the numerical columns are float32.
init = guess_schema_from_data(train_data)
pprint.pprint(init)
[('CAT1', StringTensorType(shape=[None, 1])),
 ('CAT2', StringTensorType(shape=[None, 1])),
 ('num1', FloatTensorType(shape=[None, 1])),
 ('num2', FloatTensorType(shape=[None, 1]))]

Let’s convert with skl2onnx only.

# Convert using the explicit per-column schema instead of a data sample.
onx2 = to_onnx(
    pipe,
    initial_types=init,
    options={RandomForestClassifier: {"zipmap": False}},
)

Let’s run it with onnxruntime. We need to convert the dataframe into a dictionary where column names become keys, and column values become values.

# One onnxruntime input per column, each shaped (n_rows, 1) to match the
# [None, 1] tensor types in ``init``. ``to_numpy()`` always returns a NumPy
# array, whereas ``.values`` may return a pandas ExtensionArray (e.g. for the
# string dtype) that may not support a 2-D reshape.
inputs = {
    name: column.to_numpy().reshape((-1, 1))
    for name, column in train_data.items()
}
pprint.pprint(inputs)
{'CAT1': array([['a'],
       ['b'],
       ['a'],
       ['a'],
       ['a'],
       ['a']], dtype=object),
 'CAT2': array([['c'],
       ['d'],
       ['d'],
       ['d'],
       ['c'],
       ['c']], dtype=object),
 'num1': array([[0.5 ],
       [0.4 ],
       [0.5 ],
       [0.55],
       [0.35],
       [0.5 ]], dtype=float32),
 'num2': array([[0.6 ],
       [0.8 ],
       [0.56],
       [0.56],
       [0.86],
       [0.68]], dtype=float32)}

Inference.

# Run the converted model on the per-column inputs, then compare its
# predicted labels with scikit-learn's.
sess2 = InferenceSession(
    onx2.SerializeToString(), providers=["CPUExecutionProvider"]
)
got2 = sess2.run(None, inputs)
print(pipe.predict(train_data))
print(got2[0])
[0 1 0 1 0 1]
[0 1 0 1 0 1]

And probabilities: scikit-learn first, then onnxruntime.

[[0.82 0.18]
 [0.26 0.74]
 [0.76 0.24]
 [0.37 0.63]
 [0.75 0.25]
 [0.29 0.71]]
[[0.82       0.18      ]
 [0.2600004  0.7399996 ]
 [0.76       0.24000004]
 [0.3700003  0.6299997 ]
 [0.75       0.25000003]
 [0.29000038 0.7099996 ]]

Total running time of the script: (0 minutes 0.412 seconds)

Gallery generated by Sphinx-Gallery