Note
Click here to download the full example code or to run this example in your browser via Binder
Walk through intermediate outputs¶
We reuse the example Convert a pipeline with ColumnTransformer and walk through its intermediate outputs. A converted model may produce different outputs, or fail, because a custom converter is not correctly implemented. One option is to look at the output of every node of the ONNX graph.
Create and train a complex pipeline¶
We reuse the pipeline implemented in the example
Column Transformer with Mixed Types.
There is one change: the
ONNX-ML Imputer
does not handle strings, so the imputer applied to string columns cannot be part
of the final ONNX pipeline and must be removed. Look for the comment starting with ---
below.
import onnx
import sklearn
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
from skl2onnx.helpers.onnx_helper import select_model_inputs_outputs
from skl2onnx.helpers.onnx_helper import save_onnx_model
from skl2onnx.helpers.onnx_helper import enumerate_model_node_outputs
from skl2onnx.helpers.onnx_helper import load_onnx_model
import numpy
import onnxruntime as rt
from skl2onnx import convert_sklearn, __version__
import pprint
from skl2onnx.common.data_types import FloatTensorType, StringTensorType
from skl2onnx.common.data_types import Int64TensorType
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
titanic_url = ('https://raw.githubusercontent.com/amueller/'
               'scipy-2017-sklearn/091d371/notebooks/datasets/titanic3.csv')
data = pd.read_csv(titanic_url)
X = data.drop('survived', axis=1)
y = data['survived']
# SimpleImputer on strings is not available in the ONNX-ML specifications,
# so missing categories are filled beforehand.
# The result is assigned back instead of calling
# ``X[cat].fillna(..., inplace=True)``: that is an inplace call on a chained
# selection, which under pandas Copy-on-Write leaves X unchanged.
for cat in ['embarked', 'sex', 'pclass']:
    X[cat] = X[cat].fillna('missing')
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Numerical columns: impute missing values with the median, then standardise.
numeric_features = ['age', 'fare']
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler()),
])

# Categorical columns: one-hot encode, ignoring categories unseen at fit time.
categorical_features = ['embarked', 'sex', 'pclass']
categorical_transformer = Pipeline(steps=[
    # --- SimpleImputer is not available for strings in ONNX-ML specifications.
    # ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('onehot', OneHotEncoder(handle_unknown='ignore')),
])

# Route each group of columns to its own sub-pipeline.
preprocessor = ColumnTransformer(transformers=[
    ('num', numeric_transformer, numeric_features),
    ('cat', categorical_transformer, categorical_features),
])

# Full model: preprocessing followed by a logistic regression.
clf = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LogisticRegression(solver='lbfgs')),
])
clf.fit(X_train, y_train)
Out:
Pipeline(steps=[('preprocessor',
ColumnTransformer(transformers=[('num',
Pipeline(steps=[('imputer',
SimpleImputer(strategy='median')),
('scaler',
StandardScaler())]),
['age', 'fare']),
('cat',
Pipeline(steps=[('onehot',
OneHotEncoder(handle_unknown='ignore'))]),
['embarked', 'sex',
'pclass'])])),
('classifier', LogisticRegression())])
Define the inputs of the ONNX graph¶
sklearn-onnx does not know which features were used to train the model, but it needs to know the name and type of each input. We simply reuse the dataframe column definitions.
print(X_train.dtypes)
Out:
pclass int64
name object
sex object
age float64
sibsp int64
parch int64
ticket object
fare float64
cabin object
embarked object
boat object
body float64
home.dest object
dtype: object
The following function converts the dataframe schema into the list of inputs expected by sklearn-onnx.
def convert_dataframe_schema(df, drop=None):
    """Describe the columns of a dataframe as sklearn-onnx inputs.

    Every kept column becomes one input named after the column, with
    shape ``[None, 1]`` (any number of rows, one column).

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe whose column names and dtypes define the inputs.
    drop : iterable of str or str, optional
        Column name(s) to leave out, e.g. columns the pipeline never uses.

    Returns
    -------
    list of (str, tensor type)
        Integer columns map to ``Int64TensorType``, floating-point columns
        to ``FloatTensorType`` and every other column to
        ``StringTensorType``.
    """
    if drop is None:
        excluded = set()
    elif isinstance(drop, str):
        # A single column name, not a collection of characters.
        excluded = {drop}
    else:
        excluded = set(drop)

    inputs = []
    for name, dtype in df.dtypes.items():
        if name in excluded:
            continue
        # Test dtype families rather than the exact names 'int64'/'float64'
        # so int32/float32 (or nullable integer) columns are not wrongly
        # declared as string inputs.
        if pd.api.types.is_integer_dtype(dtype):
            tensor_type = Int64TensorType([None, 1])
        elif pd.api.types.is_float_dtype(dtype):
            tensor_type = FloatTensorType([None, 1])
        else:
            tensor_type = StringTensorType([None, 1])
        inputs.append((name, tensor_type))
    return inputs
# Describe every column of X_train (nothing dropped yet) and show the result.
inputs = convert_dataframe_schema(X_train, drop=None)
pprint.pprint(inputs)
Out:
[('pclass', Int64TensorType(shape=[None, 1])),
('name', StringTensorType(shape=[None, 1])),
('sex', StringTensorType(shape=[None, 1])),
('age', FloatTensorType(shape=[None, 1])),
('sibsp', Int64TensorType(shape=[None, 1])),
('parch', Int64TensorType(shape=[None, 1])),
('ticket', StringTensorType(shape=[None, 1])),
('fare', FloatTensorType(shape=[None, 1])),
('cabin', StringTensorType(shape=[None, 1])),
('embarked', StringTensorType(shape=[None, 1])),
('boat', StringTensorType(shape=[None, 1])),
('body', FloatTensorType(shape=[None, 1])),
('home.dest', StringTensorType(shape=[None, 1]))]
Merging single columns into vectors is not the most efficient way to compute the prediction. It could be done before converting the pipeline into a graph.
Convert the pipeline into ONNX¶
# First attempt, kept to show the error sklearn-onnx reports: the
# exception is printed rather than raised so the example can continue.
try:
    model_onnx = convert_sklearn(
        clf, 'pipeline_titanic', inputs, target_opset=12)
except Exception as exc:
    print(exc)
Out:
Isolated variables exist: {'parch', 'boat', 'sibsp', 'cabin', 'ticket', 'home_dest', 'body', 'name'}
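scikit-learn does implicit conversions when it can; sklearn-onnx does not. The conversion above fails because the inputs include columns the pipeline never uses (the isolated variables listed in the error), so they must be dropped. Moreover, the ONNX version of OneHotEncoder must be applied on columns of the same type, so pclass is converted to strings.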
# The ONNX OneHotEncoder needs all categorical columns to share one type,
# so pclass is converted to strings. astype returns new DataFrames, which
# avoids the SettingWithCopyWarning raised by assigning into a column of
# the frames returned by train_test_split.
X_train = X_train.astype({'pclass': str})
X_test = X_test.astype({'pclass': str})
# clf was fitted with integer pclass values. Its OneHotEncoder would treat
# the strings '1', '2', '3' as unknown categories (all zeros, because of
# handle_unknown='ignore'), while the converted ONNX encoder matches them
# as strings, so scikit-learn and onnxruntime probabilities would disagree.
# Refit so the fitted categories are the strings the ONNX model uses.
clf.fit(X_train, y_train)
# Only declare the columns the pipeline consumes; unused inputs make the
# conversion fail with "Isolated variables exist".
white_list = numeric_features + categorical_features
to_drop = [c for c in X_train.columns if c not in white_list]
inputs = convert_dataframe_schema(X_train, to_drop)
model_onnx = convert_sklearn(clf, 'pipeline_titanic', inputs,
                             target_opset=12)
# And save.
with open("pipeline_titanic.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())
Out:
/home/dupre/github_xadupre/sklearn-onnx/docs/examples/plot_intermediate_outputs.py:143: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
X_train['pclass'] = X_train['pclass'].astype(str)
/home/dupre/github_xadupre/sklearn-onnx/docs/examples/plot_intermediate_outputs.py:144: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
X_test['pclass'] = X_test['pclass'].astype(str)
Compare the predictions¶
Final step, we need to ensure the converted model produces the same predictions, labels and probabilities. Let’s start with scikit-learn.
Out:
predict [0 1 0 1 0]
predict_proba [[0.74757949 0.25242051]]
Predictions with onnxruntime. We need to remove the dropped columns and to convert the double (float64) columns into float (float32) ones, because the ONNX graph declares them as float tensors. onnxruntime does not accept a dataframe: inputs must be given as a dictionary mapping each input name to an array. Last detail: every column was described not as a vector but as a matrix with one column, which explains the reshape in the last line.
We are ready to run onnxruntime.
Out:
predict [0 1 0 1 0]
predict_proba [{0: 0.9081845283508301, 1: 0.09181544184684753}]
Compute intermediate outputs¶
Unfortunately, there is no way to ask onnxruntime to retrieve the output of intermediate nodes. We need to modify the ONNX graph before it is given to onnxruntime. Let’s first look at the list of intermediate outputs.
Out:
age_cast
fare_cast
merged_columns
variable
variable1
embarkedout
sexout
pclassout
concat_result
variable2
variable1_cast
variable2_cast
transformed_column
label
probability_tensor
probabilities
output_label
output_probability
It is not easy to tell which output is which, as the ONNX graph has more operators than the original scikit-learn pipeline. The graph in Display the ONNX graph helps us find the outputs of both the numerical and the textual pipelines: variable1 and variable2. Let’s look into the numerical pipeline first.
# Keep only the part of the graph needed to compute 'variable1', the output
# of the numerical sub-pipeline (Cast -> Concat -> Imputer -> Scaler), and
# save it as a standalone ONNX model.
num_onnx = select_model_inputs_outputs(model_onnx, 'variable1')
save_onnx_model(num_onnx, "pipeline_titanic_numerical.onnx")
Out:
b'\x08\x06\x12\x08skl2onnx\x1a\x051.7.1"\x07ai.onnx(\x002\x00:\xae\x04\n^\n\x08variable\x12\tvariable1\x1a\x06Scaler"\x06Scaler*\x15\n\x06offset=_\xe4\xe8A=\xc5\x1e\x03B\xa0\x01\x06*\x14\n\x05scale=/\x83\x9c==\xfb\xf1\xa2<\xa0\x01\x06:\nai.onnx.ml\n}\n\x0emerged_columns\x12\x08variable\x1a\x07Imputer"\x07Imputer*#\n\x14imputed_value_floats=\x00\x00\xe0A=gDgA\xa0\x01\x06*\x1e\n\x14replaced_value_float\x15\x00\x00\xc0\x7f\xa0\x01\x01:\nai.onnx.ml\nD\n\x08age_cast\n\tfare_cast\x12\x0emerged_columns\x1a\x06Concat"\x06Concat*\x0b\n\x04axis\x18\x01\xa0\x01\x02:\x00\n+\n\x04fare\x12\tfare_cast\x1a\x05Cast1"\x04Cast*\t\n\x02to\x18\x01\xa0\x01\x02:\x00\n(\n\x03age\x12\x08age_cast\x1a\x04Cast"\x04Cast*\t\n\x02to\x18\x01\xa0\x01\x02:\x00\x12\x10pipeline_titanic*\x1f\x08\x02\x10\x07:\x0b\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\tB\x0cshape_tensorZ\x16\n\x06pclass\x12\x0c\n\n\x08\x08\x12\x06\n\x00\n\x02\x08\x01Z\x13\n\x03sex\x12\x0c\n\n\x08\x08\x12\x06\n\x00\n\x02\x08\x01Z\x13\n\x03age\x12\x0c\n\n\x08\x01\x12\x06\n\x00\n\x02\x08\x01Z\x14\n\x04fare\x12\x0c\n\n\x08\x01\x12\x06\n\x00\n\x02\x08\x01Z\x18\n\x08embarked\x12\x0c\n\n\x08\x08\x12\x06\n\x00\n\x02\x08\x01b\x0b\n\tvariable1B\x02\x10\x0cB\x04\n\x00\x10\x0bB\x0e\n\nai.onnx.ml\x10\x01'
Let’s compute the numerical features.
Out:
numerical features [[-0.6198985 -0.49189988]]
We do the same for the textual features.
Out:
ir_version: 6
producer_name: "skl2onnx"
producer_version: "1.7.1"
domain: "ai.onnx"
model_version: 0
doc_string: ""
graph {
node {
input: "age"
output: "age_cast"
name: "Cast"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
domain: ""
}
node {
input: "fare"
output: "fare_cast"
name: "Cast1"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
domain: ""
}
node {
input: "age_cast"
input: "fare_cast"
output: "merged_columns"
name: "Concat"
op_type: "Concat"
attribute {
name: "axis"
i: 1
type: INT
}
domain: ""
}
node {
input: "merged_columns"
output: "variable"
name: "Imputer"
op_type: "Imputer"
attribute {
name: "imputed_value_floats"
floats: 28.0
floats: 14.45419979095459
type: FLOATS
}
attribute {
name: "replaced_value_float"
f: nan
type: FLOAT
}
domain: "ai.onnx.ml"
}
node {
input: "variable"
output: "variable1"
name: "Scaler"
op_type: "Scaler"
attribute {
name: "offset"
floats: 29.111509323120117
floats: 32.78004837036133
type: FLOATS
}
attribute {
name: "scale"
floats: 0.07642208784818649
floats: 0.01989077590405941
type: FLOATS
}
domain: "ai.onnx.ml"
}
node {
input: "embarked"
output: "embarkedout"
name: "OneHotEncoder"
op_type: "OneHotEncoder"
attribute {
name: "cats_strings"
strings: "C"
strings: "Q"
strings: "S"
strings: "missing"
type: STRINGS
}
attribute {
name: "zeros"
i: 1
type: INT
}
domain: "ai.onnx.ml"
}
node {
input: "sex"
output: "sexout"
name: "OneHotEncoder1"
op_type: "OneHotEncoder"
attribute {
name: "cats_strings"
strings: "female"
strings: "male"
type: STRINGS
}
attribute {
name: "zeros"
i: 1
type: INT
}
domain: "ai.onnx.ml"
}
node {
input: "pclass"
output: "pclassout"
name: "OneHotEncoder2"
op_type: "OneHotEncoder"
attribute {
name: "cats_strings"
strings: "1"
strings: "2"
strings: "3"
type: STRINGS
}
attribute {
name: "zeros"
i: 1
type: INT
}
domain: "ai.onnx.ml"
}
node {
input: "embarkedout"
input: "sexout"
input: "pclassout"
output: "concat_result"
name: "Concat1"
op_type: "Concat"
attribute {
name: "axis"
i: 2
type: INT
}
domain: ""
}
node {
input: "concat_result"
input: "shape_tensor"
output: "variable2"
name: "Reshape"
op_type: "Reshape"
domain: ""
}
node {
input: "variable1"
output: "variable1_cast"
name: "Cast2"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
domain: ""
}
node {
input: "variable2"
output: "variable2_cast"
name: "Cast3"
op_type: "Cast"
attribute {
name: "to"
i: 1
type: INT
}
domain: ""
}
node {
input: "variable1_cast"
input: "variable2_cast"
output: "transformed_column"
name: "Concat2"
op_type: "Concat"
attribute {
name: "axis"
i: 1
type: INT
}
domain: ""
}
node {
input: "transformed_column"
output: "label"
output: "probability_tensor"
name: "LinearClassifier"
op_type: "LinearClassifier"
attribute {
name: "classlabels_ints"
ints: 0
ints: 1
type: INTS
}
attribute {
name: "coefficients"
floats: 0.4088466167449951
floats: 0.048030976206064224
floats: -0.3020475506782532
floats: 0.1128229945898056
floats: 0.4023498296737671
floats: -0.21311041712760925
floats: -1.2705841064453125
floats: 1.2705990076065063
floats: -1.084059238433838
floats: -0.121848464012146
floats: 1.2059226036071777
floats: -0.4088466167449951
floats: -0.048030976206064224
floats: 0.3020475506782532
floats: -0.1128229945898056
floats: -0.4023498296737671
floats: 0.21311041712760925
floats: 1.2705841064453125
floats: -1.2705990076065063
floats: 1.084059238433838
floats: 0.121848464012146
floats: -1.2059226036071777
type: FLOATS
}
attribute {
name: "intercepts"
floats: -0.3101347088813782
floats: 0.3101347088813782
type: FLOATS
}
attribute {
name: "multi_class"
i: 1
type: INT
}
attribute {
name: "post_transform"
s: "LOGISTIC"
type: STRING
}
domain: "ai.onnx.ml"
}
node {
input: "probability_tensor"
output: "probabilities"
name: "Normalizer"
op_type: "Normalizer"
attribute {
name: "norm"
s: "L1"
type: STRING
}
domain: "ai.onnx.ml"
}
node {
input: "label"
output: "output_label"
name: "Cast4"
op_type: "Cast"
attribute {
name: "to"
i: 7
type: INT
}
domain: ""
}
node {
input: "probabilities"
output: "output_probability"
name: "ZipMap"
op_type: "ZipMap"
attribute {
name: "classlabels_int64s"
ints: 0
ints: 1
type: INTS
}
domain: "ai.onnx.ml"
}
name: "pipeline_titanic"
initializer {
dims: 2
data_type: 7
int64_data: -1
int64_data: 9
name: "shape_tensor"
}
input {
name: "pclass"
type {
tensor_type {
elem_type: 8
shape {
dim {
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "sex"
type {
tensor_type {
elem_type: 8
shape {
dim {
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "age"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "fare"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
dim {
dim_value: 1
}
}
}
}
}
input {
name: "embarked"
type {
tensor_type {
elem_type: 8
shape {
dim {
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "output_label"
type {
tensor_type {
elem_type: 7
shape {
dim {
}
}
}
}
}
output {
name: "output_probability"
type {
sequence_type {
elem_type {
map_type {
key_type: 7
value_type {
tensor_type {
elem_type: 1
}
}
}
}
}
}
}
}
opset_import {
domain: ""
version: 11
}
opset_import {
domain: "ai.onnx.ml"
version: 1
}
textual features [[0. 0. 1. 0. 0. 1. 0. 0. 1.]]
Display the sub-ONNX graph¶
Finally, let’s look at both subgraphs. First, the numerical pipeline.
import subprocess

# Render the numerical sub-graph with Graphviz and display the PNG.
pydot_graph = GetPydotGraph(
    num_onnx.graph, name=num_onnx.graph.name, rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"))
pydot_graph.write_dot("pipeline_titanic_num.dot")
# Run Graphviz without a shell and fail loudly. os.system ignored a
# non-zero exit status, so a missing or failing ``dot`` surfaced later as
# a confusing imread error, or silently showed a stale PNG from an
# earlier run. ``-O`` writes pipeline_titanic_num.dot.png.
subprocess.run(
    ['dot', '-O', '-Gdpi=300', '-Tpng', 'pipeline_titanic_num.dot'],
    check=True)
image = plt.imread("pipeline_titanic_num.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis('off')

Out:
(-0.5, 4393.5, 1033.5, -0.5)
Then, the textual pipeline.
import subprocess

# Render the textual sub-graph with Graphviz and display the PNG.
# NOTE(review): text_onnx is not defined in the code shown on this page; it
# presumably comes from the textual-features step (selecting 'variable2'),
# whose code is omitted here — confirm.
pydot_graph = GetPydotGraph(
    text_onnx.graph, name=text_onnx.graph.name, rankdir="TB",
    node_producer=GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled"))
pydot_graph.write_dot("pipeline_titanic_text.dot")
# Run Graphviz without a shell and fail loudly. os.system ignored a
# non-zero exit status, so a missing or failing ``dot`` surfaced later as
# a confusing imread error, or silently showed a stale PNG from an
# earlier run. ``-O`` writes pipeline_titanic_text.dot.png.
subprocess.run(
    ['dot', '-O', '-Gdpi=300', '-Tpng', 'pipeline_titanic_text.dot'],
    check=True)
image = plt.imread("pipeline_titanic_text.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis('off')

Out:
(-0.5, 6847.5, 1121.5, -0.5)
Versions used for this example
# Report the library versions this example was run with; each label is
# printed exactly as written, followed by a space and the version.
for _label, _version in (
        ("numpy:", numpy.__version__),
        ("scikit-learn:", sklearn.__version__),
        ("onnx: ", onnx.__version__),
        ("onnxruntime: ", rt.__version__),
        ("skl2onnx: ", __version__)):
    print(_label, _version)
Out:
numpy: 1.19.1
scikit-learn: 0.23.2
onnx: 1.7.0
onnxruntime: 1.4.0
skl2onnx: 1.7.1
Total running time of the script: ( 0 minutes 3.102 seconds)