Compare CDist with scipy

The following example focuses on one particular operator, CDist, and compares its execution time in onnxruntime with that of scipy.

ONNX Graph with CDist

The cdist function computes the pairwise distances between the rows of two matrices.

from pprint import pprint
from timeit import Timer
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from pandas import DataFrame
import onnx
import onnxruntime as rt
from onnxruntime import InferenceSession
import skl2onnx
from skl2onnx.algebra.custom_ops import OnnxCDist
from skl2onnx.common.data_types import FloatTensorType

# Toy inputs: two rows of ones compared against three rows of twos,
# every row having four float32 columns.
X = np.ones((2, 4), dtype=np.float32)
Y = np.ones((3, 4), dtype=np.float32) * 2
print(cdist(X, Y, metric="euclidean"))
[[2. 2. 2.]
 [2. 2. 2.]]

ONNX

# Describe a one-node graph Z = CDist(X, Y) using the euclidean metric,
# then convert it to an ONNX model; X and Y serve as sample inputs.
op = OnnxCDist(
    "X",
    "Y",
    metric="euclidean",
    output_names=["Z"],
    op_version=12,
)
declared_outputs = [("Z", FloatTensorType())]
onx = op.to_onnx({"X": X, "Y": Y}, outputs=declared_outputs)
print(onx)
ir_version: 8
opset_import {
  domain: "com.microsoft"
  version: 1
}
producer_name: "skl2onnx"
producer_version: "1.15.0"
domain: "ai.onnx"
model_version: 0
graph {
  node {
    input: "X"
    input: "Y"
    output: "Z"
    name: "CD_CDist"
    op_type: "CDist"
    domain: "com.microsoft"
    attribute {
      name: "metric"
      type: STRING
      s: "euclidean"
    }
  }
  name: "OnnxCDist"
  input {
    name: "X"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  input {
    name: "Y"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  output {
    name: "Z"
    type {
      tensor_type {
        elem_type: 1
      }
    }
  }
}

CDist and onnxruntime

We compute the output of the CDist operator with onnxruntime.

# Load the serialized model into a CPU-only session and run it once
# on the toy inputs.
model_bytes = onx.SerializeToString()
sess = InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
feeds = {"X": X, "Y": Y}
res = sess.run(None, feeds)
print(res)
[array([[1.9999999, 1.9999999, 1.9999999],
       [1.9999999, 2.       , 2.       ]], dtype=float32)]

Benchmark

Let’s compare onnxruntime and scipy.

def measure_time(name, stmt, context, repeat=100, number=20):
    """Time ``stmt`` with :class:`timeit.Timer` and summarise the results.

    The statement is executed ``number`` times per measurement and the
    measurement is repeated ``repeat`` times; every reported duration is
    the time of a single execution (measurement divided by ``number``).

    :param name: label stored unchanged in the returned dictionary
    :param stmt: statement to time, evaluated with ``context`` as globals
    :param context: names visible to ``stmt``; it must contain ``"X"`` and
        ``"Y"``, whose ``shape`` attributes give ``nrows`` (``X.shape[0]``)
        and ``ncols`` (``Y.shape[1]``)
    :param repeat: number of measurements
    :param number: number of executions per measurement
    :return: dictionary with keys ``average``, ``deviation``, ``min_exec``,
        ``max_exec``, ``repeat``, ``number``, ``nrows``, ``ncols``, ``name``
    """
    tim = Timer(stmt, globals=context)
    # Timer.repeat returns the total time of `number` executions for each
    # measurement; dividing gives the per-execution time.
    res = np.array(tim.repeat(repeat=repeat, number=number)) / number
    return dict(
        average=np.mean(res),
        # np.std is the same population deviation as sqrt(E[x^2] - E[x]^2)
        # but cannot turn into NaN when rounding makes that difference
        # slightly negative (nearly identical timings).
        deviation=np.std(res),
        min_exec=np.min(res),
        max_exec=np.max(res),
        repeat=repeat,
        number=number,
        nrows=context["X"].shape[0],
        ncols=context["Y"].shape[1],
        name=name,
    )

scipy

# Time scipy's cdist on the toy inputs.
scipy_context = {"cdist": cdist, "X": X, "Y": Y}
time_scipy = measure_time("scipy", "cdist(X, Y)", context=scipy_context)
pprint(time_scipy)
{'average': 9.84484999992219e-06,
 'deviation': 4.450414501003498e-06,
 'max_exec': 4.325500000277316e-05,
 'min_exec': 3.990000004705508e-06,
 'name': 'scipy',
 'ncols': 4,
 'nrows': 2,
 'number': 20,
 'repeat': 100}

onnxruntime

# Time the onnxruntime session on the same toy inputs.
ort_context = {"sess": sess, "X": X, "Y": Y}
time_ort = measure_time(
    "ort",
    "sess.run(None, {'X': X, 'Y': Y})",
    context=ort_context,
)
pprint(time_ort)
{'average': 1.90378500000179e-05,
 'deviation': 9.14980003126398e-06,
 'max_exec': 5.8449999994536486e-05,
 'min_exec': 1.1034999999992579e-05,
 'name': 'ort',
 'ncols': 4,
 'nrows': 2,
 'number': 20,
 'repeat': 100}

Longer benchmark

# Benchmark both implementations while the number of rows of X grows.
metrics = []
for dim in tqdm([10, 100, 1000, 10000]):
    # Only the row count varies: changing the number of columns
    # would require building a new graph.
    X = np.random.randn(dim, 4).astype(np.float32)
    Y = np.random.randn(10, 4).astype(np.float32)

    scipy_context = {"cdist": cdist, "X": X, "Y": Y}
    ort_context = {"sess": sess, "X": X, "Y": Y}

    time_scipy = measure_time("scipy", "cdist(X, Y)", context=scipy_context)
    time_ort = measure_time(
        "ort",
        "sess.run(None, {'X': X, 'Y': Y})",
        context=ort_context,
    )

    # Keep only the average per-call duration of each implementation.
    metric = {
        "N": dim,
        "scipy": time_scipy["average"],
        "ort": time_ort["average"],
    }
    metrics.append(metric)

df = DataFrame(metrics)
# A ratio above 1 means onnxruntime was faster than scipy on average.
df["scipy/ort"] = df["scipy"] / df["ort"]
print(df)

df.plot(x="N", y=["scipy/ort"])
plot benchmark cdist
  0%|          | 0/4 [00:00<?, ?it/s]
 50%|█████     | 2/4 [00:00<00:00, 19.40it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]
100%|██████████| 4/4 [00:02<00:00,  1.64it/s]
       N     scipy       ort  scipy/ort
0     10  0.000010  0.000014   0.695364
1    100  0.000012  0.000015   0.824828
2   1000  0.000095  0.000041   2.312891
3  10000  0.000741  0.000287   2.578285

Versions used for this example

# Report the version of every library this example relies on.
for label, module in (
    ("numpy:", np),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("skl2onnx: ", skl2onnx),
):
    print(label, module.__version__)
numpy: 1.23.5
onnx:  1.15.0
onnxruntime:  1.16.0+cu118
skl2onnx:  1.15.0

Total running time of the script: (0 minutes 2.761 seconds)

Gallery generated by Sphinx-Gallery