Transfer Learning with ONNX#

Transfer learning is common with deep learning. A deep learning model is used as preprocessing before the output is sent to a final classifier or regressor. It is not easy in this case to mix frameworks: scikit-learn with pytorch (or skorch), the Keras API for Tensorflow, or tf.keras.wrappers.scikit_learn. Every combination requires work. ONNX reduces the number of platforms to support. Once the model is converted into ONNX, it can be inserted in any scikit-learn pipeline.

Retrieve and load a model#

We download one model from the :epkg:`ONNX Zoo` but the model could be trained and produced by another converter library.

import sys
from io import BytesIO
import onnx
from mlprodict.sklapi import OnnxTransformer
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from mlinsights.plotting.gallery import plot_gallery_images
import matplotlib.pyplot as plt
from skl2onnx.tutorial.imagenet_classes import class_names
import numpy
from PIL import Image
from onnxruntime import InferenceSession
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument
import os
import urllib.request


def download_file(url, name, min_size):
    """Download *url* into the local file *name* unless it already exists.

    Parameters
    ----------
    url : str
        Address to fetch.
    name : str
        Destination file name; the function is a no-op if it already exists.
    min_size : int
        Minimal expected size in bytes; a shorter payload is treated as an
        error page rather than the real file.

    Raises
    ------
    RuntimeError
        If the downloaded content is shorter than *min_size*.
    """
    if os.path.exists(name):
        print("'%s' already downloaded" % name)
        return
    print("download '%s'" % url)
    with urllib.request.urlopen(url) as u:
        content = u.read()
    if len(content) < min_size:
        # Only include the beginning of the payload in the message:
        # the previous version dumped the whole (possibly binary,
        # possibly multi-megabyte) body into the exception text.
        raise RuntimeError(
            "Unable to download '{}' due to\n{}".format(
                url, content[:200]))
    print("downloaded %d bytes." % len(content))
    with open(name, "wb") as f:
        f.write(content)


# Pretrained SqueezeNet 1.1 (opset 7) from the ONNX Zoo.
model_name = "squeezenet1.1-7.onnx"
base_url = ("https://github.com/onnx/models/raw/main/vision/"
            "classification/squeezenet/model")
url_name = "/".join([base_url, model_name])
try:
    # The model weighs several megabytes; anything much smaller is an
    # error page, which download_file reports as a RuntimeError.
    download_file(url_name, model_name, 100000)
except RuntimeError as e:
    print(e)
    sys.exit(1)

Out:

download 'https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.1-7.onnx'
downloaded 4956208 bytes.

Loading the ONNX file and using it on one image.

Out:

NodeArg(name='data', type='tensor(float)', shape=[1, 3, 224, 224])

The model expects a series of images of size [3, 224, 224].

Classifying an image#

# Wikimedia photograph used as the test image.
url = ("https://upload.wikimedia.org/wikipedia/commons/d/d2/"
       "East_Coker_elm%2C_2.jpg")
img = "East_Coker_elm.jpg"
download_file(url, img, 100000)

# Keep both the full-resolution picture (for rotations later on) and a
# copy resized to the model's 224x224 input. Call im.show() to display it.
im0 = Image.open(img)
im = im0.resize((224, 224))

Out:

download 'https://upload.wikimedia.org/wikipedia/commons/d/d2/East_Coker_elm%2C_2.jpg'
downloaded 712230 bytes.

Image to numpy array and prediction.

def im2array(im):
    """Convert an image into the NCHW batch layout the model expects.

    Parameters
    ----------
    im : PIL.Image.Image or array-like
        RGB image laid out as (height, width, channels).

    Returns
    -------
    numpy.ndarray
        Array of shape (1, channels, height, width); the dtype is whatever
        ``numpy.asarray`` infers (uint8 for a PIL image).
    """
    X = numpy.asarray(im)
    # HWC -> CHW, then prepend the batch axis. Works for any image size;
    # the previous version hard-coded reshape(1, 3, 224, 224) and failed
    # on images of any other dimension.
    X = X.transpose(2, 0, 1)
    return X[numpy.newaxis, ...]


# Convert the resized image to the NCHW float tensor and run inference.
# NOTE(review): `sess` is the onnxruntime InferenceSession created in the
# loading snippet omitted from this excerpt -- presumably
# InferenceSession(model_name); confirm against the full script.
X = im2array(im)
out = sess.run(None, {'data': X.astype(numpy.float32)})
out = out[0]

# Raw scores (not softmax probabilities) for the first five ImageNet classes.
print(out[0, :5])

Out:

[145.59464   55.06765   60.599792  46.293953  37.982464]

Interpretation

res = list(sorted((r, class_names[i]) for i, r in enumerate(out[0])))
print(res[-5:])

Out:

[(205.84174, 'Samoyed, Samoyede'), (212.03664, 'park bench'), (225.50691, 'lakeside, lakeshore'), (232.90251, 'fountain'), (258.10968, 'geyser')]

Classifying more images#

The initial image is rotated, the answer is changing.

# Rotate the original picture in 2-degree steps over [-12, 10] degrees and
# resize every rotated copy to the model's 224x224 input size.
angles = [2.0 * k for k in range(-6, 6)]
imgs = [(angle, im0.rotate(angle).resize((224, 224)))
        for angle in angles]


def classify(imgs):
    """Classify every (angle, image) pair with the SqueezeNet session.

    Returns a list of (angle, ranked) pairs where *ranked* holds the
    (score, class name) tuples sorted from highest to lowest score.
    """
    labels = []
    for angle, img in imgs:
        batch = im2array(img)
        scores = sess.run(None, {'data': batch.astype(numpy.float32)})[0]
        ranked = sorted(
            ((s, class_names[i]) for i, s in enumerate(scores[0])),
            reverse=True)
        labels.append((angle, ranked))
    return labels


# Classify every rotated image and show the five best guesses per angle.
climgs = classify(imgs)
for angle, ranked in climgs:
    print("angle={} - {}".format(angle, ranked[:5]))


# Gallery of the rotated images, each captioned with its top label
# (truncated to 15 characters).
plot_gallery_images([pair[1] for pair in imgs],
                    [cl[1][0][1][:15] for cl in climgs])
(Image: gallery of the rotated images with their predicted labels.)

Out:

angle=-12.0 - [(247.06146, 'obelisk'), (238.95372, 'car mirror'), (235.27646, 'flagpole, flagstaff'), (231.51707, 'window screen'), (230.90657, 'picket fence, paling')]
angle=-10.0 - [(254.24683, 'car mirror'), (251.51357, 'obelisk'), (235.10512, 'groom, bridegroom'), (234.5295, 'picket fence, paling'), (232.13913, 'church, church building')]
angle=-8.0 - [(235.56952, 'obelisk'), (226.59697, 'car mirror'), (226.46773, 'picket fence, paling'), (221.46794, 'groom, bridegroom'), (220.88506, 'fountain')]
angle=-6.0 - [(265.50806, 'geyser'), (243.68619, 'obelisk'), (238.92957, 'fountain'), (226.73683, 'pedestal, plinth, footstall'), (226.11952, 'Great Pyrenees')]
angle=-4.0 - [(287.7449, 'geyser'), (255.25323, 'fountain'), (236.84944, 'obelisk'), (223.02913, 'Great Pyrenees'), (222.80464, 'church, church building')]
angle=-2.0 - [(267.63528, 'geyser'), (251.48958, 'fountain'), (214.64241, 'obelisk'), (214.56227, 'mobile home, manufactured home'), (213.12424, 'flagpole, flagstaff')]
angle=0.0 - [(258.10968, 'geyser'), (232.90251, 'fountain'), (225.50691, 'lakeside, lakeshore'), (212.03664, 'park bench'), (205.84174, 'Samoyed, Samoyede')]
angle=2.0 - [(222.74826, 'geyser'), (213.38457, 'fountain'), (212.24376, 'obelisk'), (198.3714, 'beacon, lighthouse, beacon light, pharos'), (197.43805, 'picket fence, paling')]
angle=4.0 - [(221.34749, 'geyser'), (209.60362, 'fountain'), (207.0692, 'American egret, great white heron, Egretta albus'), (201.63098, 'obelisk'), (198.75673, 'Great Pyrenees')]
angle=6.0 - [(230.98735, 'American egret, great white heron, Egretta albus'), (216.6342, 'fountain'), (212.73236, 'groom, bridegroom'), (209.60934, 'flagpole, flagstaff'), (209.46207, 'swimming trunks, bathing trunks')]
angle=8.0 - [(253.32706, 'American egret, great white heron, Egretta albus'), (222.6997, 'golf ball'), (222.50499, 'groom, bridegroom'), (222.36351, 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita'), (217.73135, 'swimming trunks, bathing trunks')]
angle=10.0 - [(244.3011, 'solar dish, solar collector, solar furnace'), (239.57332, 'flagpole, flagstaff'), (234.92139, 'picket fence, paling'), (230.62114, 'car mirror'), (221.8794, 'screen, CRT screen')]

array([[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]],
      dtype=object)

Transfer learning in a pipeline#

The proposed transfer learning consists of using a PCA to project the probabilities onto a graph.

# OnnxTransformer consumes the serialized graph, so read the raw bytes.
with open(model_name, 'rb') as f:
    model_bytes = f.read()

# The deep model acts as a feature extractor; PCA projects its 1000
# class scores down to two dimensions.
pipe = Pipeline(steps=[
    ('deep', OnnxTransformer(
        model_bytes, runtime='onnxruntime1', change_batch_size=0)),
    ('pca', PCA(2)),
])

# Stack the twelve rotated images into a single float32 batch.
X_train = numpy.vstack(
    [im2array(pair[1]) for pair in imgs]).astype(numpy.float32)
pipe.fit(X_train)

proj = pipe.transform(X_train)
print(proj)

Out:

[[-676.57574  -203.35419 ]
 [-570.6652   -208.09737 ]
 [-339.8119    -86.340065]
 [ -14.556101 -168.44867 ]
 [ 357.2234   -157.61446 ]
 [ 596.38617   -90.210075]
 [ 918.86115   -26.339502]
 [ 499.8716    128.27264 ]
 [ 306.68567   156.4289  ]
 [-125.912094  119.218   ]
 [-446.60455   342.4585  ]
 [-504.9024    194.02617 ]]

Graph for the PCA#

# Scatter plot of the 2D PCA projection, one point per rotated image,
# annotated with "<angle>-<top label>".
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(proj[:, 0], proj[:, 1], 'o')
ax.set_title("Projection of classification probabilities")
text = ["%1.0f-%s" % (el[0], el[1][0][1]) for el in climgs]
for label, point in zip(text, proj):
    ax.annotate(
        label, xy=(point[0], point[1]), xytext=(-10, 10), fontsize=8,
        textcoords='offset points', ha='right', va='bottom',
        bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
Projection of classification probabilities

Remove one layer at the end#

The last layer is often removed before the model is inserted in a pipeline. Let's see how to do that. First, we need the list of outputs for every node.

# Parse the serialized model back into a protobuf graph and collect the
# output names of every node, in graph order.
model_onnx = onnx.load(BytesIO(model_bytes))
outputs = []
for node in model_onnx.graph.node:
    print(node.name, node.output)
    outputs += list(node.output)

Out:

squeezenet0_conv0_fwd ['squeezenet0_conv0_fwd']
squeezenet0_relu0_fwd ['squeezenet0_relu0_fwd']
squeezenet0_pool0_fwd ['squeezenet0_pool0_fwd']
squeezenet0_conv1_fwd ['squeezenet0_conv1_fwd']
squeezenet0_relu1_fwd ['squeezenet0_relu1_fwd']
squeezenet0_conv2_fwd ['squeezenet0_conv2_fwd']
squeezenet0_relu2_fwd ['squeezenet0_relu2_fwd']
squeezenet0_conv3_fwd ['squeezenet0_conv3_fwd']
squeezenet0_relu3_fwd ['squeezenet0_relu3_fwd']
squeezenet0_concat0 ['squeezenet0_concat0']
squeezenet0_conv4_fwd ['squeezenet0_conv4_fwd']
squeezenet0_relu4_fwd ['squeezenet0_relu4_fwd']
squeezenet0_conv5_fwd ['squeezenet0_conv5_fwd']
squeezenet0_relu5_fwd ['squeezenet0_relu5_fwd']
squeezenet0_conv6_fwd ['squeezenet0_conv6_fwd']
squeezenet0_relu6_fwd ['squeezenet0_relu6_fwd']
squeezenet0_concat1 ['squeezenet0_concat1']
squeezenet0_pool1_fwd ['squeezenet0_pool1_fwd']
squeezenet0_conv7_fwd ['squeezenet0_conv7_fwd']
squeezenet0_relu7_fwd ['squeezenet0_relu7_fwd']
squeezenet0_conv8_fwd ['squeezenet0_conv8_fwd']
squeezenet0_relu8_fwd ['squeezenet0_relu8_fwd']
squeezenet0_conv9_fwd ['squeezenet0_conv9_fwd']
squeezenet0_relu9_fwd ['squeezenet0_relu9_fwd']
squeezenet0_concat2 ['squeezenet0_concat2']
squeezenet0_conv10_fwd ['squeezenet0_conv10_fwd']
squeezenet0_relu10_fwd ['squeezenet0_relu10_fwd']
squeezenet0_conv11_fwd ['squeezenet0_conv11_fwd']
squeezenet0_relu11_fwd ['squeezenet0_relu11_fwd']
squeezenet0_conv12_fwd ['squeezenet0_conv12_fwd']
squeezenet0_relu12_fwd ['squeezenet0_relu12_fwd']
squeezenet0_concat3 ['squeezenet0_concat3']
squeezenet0_pool2_fwd ['squeezenet0_pool2_fwd']
squeezenet0_conv13_fwd ['squeezenet0_conv13_fwd']
squeezenet0_relu13_fwd ['squeezenet0_relu13_fwd']
squeezenet0_conv14_fwd ['squeezenet0_conv14_fwd']
squeezenet0_relu14_fwd ['squeezenet0_relu14_fwd']
squeezenet0_conv15_fwd ['squeezenet0_conv15_fwd']
squeezenet0_relu15_fwd ['squeezenet0_relu15_fwd']
squeezenet0_concat4 ['squeezenet0_concat4']
squeezenet0_conv16_fwd ['squeezenet0_conv16_fwd']
squeezenet0_relu16_fwd ['squeezenet0_relu16_fwd']
squeezenet0_conv17_fwd ['squeezenet0_conv17_fwd']
squeezenet0_relu17_fwd ['squeezenet0_relu17_fwd']
squeezenet0_conv18_fwd ['squeezenet0_conv18_fwd']
squeezenet0_relu18_fwd ['squeezenet0_relu18_fwd']
squeezenet0_concat5 ['squeezenet0_concat5']
squeezenet0_conv19_fwd ['squeezenet0_conv19_fwd']
squeezenet0_relu19_fwd ['squeezenet0_relu19_fwd']
squeezenet0_conv20_fwd ['squeezenet0_conv20_fwd']
squeezenet0_relu20_fwd ['squeezenet0_relu20_fwd']
squeezenet0_conv21_fwd ['squeezenet0_conv21_fwd']
squeezenet0_relu21_fwd ['squeezenet0_relu21_fwd']
squeezenet0_concat6 ['squeezenet0_concat6']
squeezenet0_conv22_fwd ['squeezenet0_conv22_fwd']
squeezenet0_relu22_fwd ['squeezenet0_relu22_fwd']
squeezenet0_conv23_fwd ['squeezenet0_conv23_fwd']
squeezenet0_relu23_fwd ['squeezenet0_relu23_fwd']
squeezenet0_conv24_fwd ['squeezenet0_conv24_fwd']
squeezenet0_relu24_fwd ['squeezenet0_relu24_fwd']
squeezenet0_concat7 ['squeezenet0_concat7']
squeezenet0_dropout0_fwd ['squeezenet0_dropout0_fwd']
squeezenet0_conv25_fwd ['squeezenet0_conv25_fwd']
squeezenet0_relu25_fwd ['squeezenet0_relu25_fwd']
squeezenet0_pool3_fwd ['squeezenet0_pool3_fwd']
squeezenet0_flatten0_reshape0 ['squeezenet0_flatten0_reshape0']

We select one of the last one.

# Third output from the end: the last ReLU activation
# (squeezenet0_relu25_fwd), just before the final pooling and flatten.
selected = outputs[-3]
print("selected", selected)

Out:

selected squeezenet0_relu25_fwd

And we tell OnnxTransformer to use that specific output and to flatten it, as that intermediate output is not a matrix.

# Same pipeline, but the deep step now stops at the selected intermediate
# output and flattens it so PCA receives one row per image.
pipe2 = Pipeline(steps=[
    ('deep', OnnxTransformer(
        model_bytes, runtime='onnxruntime1', change_batch_size=0,
        output_name=selected, reshape=True)),
    ('pca', PCA(2)),
])

try:
    pipe2.fit(X_train)
except InvalidArgument as e:
    # onnxruntime rejects the request when the output cannot be produced.
    print("Unable to fit due to", e)

We check that it is different. The following values are the shapes of the PCA components. The number of columns is the number of dimensions of the outputs of the transferred neural network.

# Compare the PCA component shapes of both pipelines.
shape_full = pipe.steps[1][1].components_.shape
shape_truncated = pipe2.steps[1][1].components_.shape
print(shape_full, shape_truncated)

Out:

(2, 1000) (2, 169000)

Graph again.

# Project the batch with the truncated pipeline and plot the result,
# annotated exactly like the first projection.
proj2 = pipe2.transform(X_train)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(proj2[:, 0], proj2[:, 1], 'o')
ax.set_title("Second projection of classification probabilities")
text = ["%1.0f-%s" % (el[0], el[1][0][1]) for el in climgs]
for label, point in zip(text, proj2):
    ax.annotate(
        label, xy=(point[0], point[1]), xytext=(-10, 10), fontsize=8,
        textcoords='offset points', ha='right', va='bottom',
        bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
Second projection of classification probabilities

Total running time of the script: ( 0 minutes 5.373 seconds)

Gallery generated by Sphinx-Gallery