.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_nmf.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_nmf.py:


Custom Operator for NMF Decomposition
=====================================

`NMF `_ factorizes an input matrix into two matrices *W, H* of
rank *k* so that :math:`WH \approx M`. :math:`M=(m_{ij})` may be a binary matrix
where *i* is a user and *j* a product they bought. The prediction
function depends on whether the recommendation is needed for an existing
user or for a new user. This example addresses the first case.

The second case is more complex because it theoretically requires
estimating a new matrix *W* with gradient descent.

Building a simple model
+++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 25-59

.. code-block:: default

    import os
    import skl2onnx
    import onnxruntime
    import sklearn
    from sklearn.decomposition import NMF
    import numpy as np
    import matplotlib.pyplot as plt
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    import onnx
    from skl2onnx.algebra.onnx_ops import OnnxArrayFeatureExtractor, OnnxMul, OnnxReduceSum
    from skl2onnx.common.data_types import FloatTensorType

    from onnxruntime import InferenceSession


    mat = np.array(
        [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]],
        dtype=np.float64,
    )
    mat[: mat.shape[1], :] += np.identity(mat.shape[1])

    # W holds one row of factors per user, H (components_) one column per product.
    mod = NMF(n_components=2)
    W = mod.fit_transform(mat)
    H = mod.components_
    # inverse_transform(W) computes W @ H, the reconstruction of mat.
    pred = mod.inverse_transform(W)

    print("original predictions")
    exp = []
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            exp.append((i, j, pred[i, j]))

    print(exp)


.. rst-class:: sphx-glr-script-out
.. code-block:: none

    original predictions
    [(0, 0, 1.8940507356352687), (0, 1, 0.10912372262184848), (0, 2, 0.3072453141962623), (0, 3, 0.3072453141962623), (1, 0, 1.014673742790866), (1, 1, 0.9848866016414943), (1, 2, 0.0), (1, 3, 0.0), (2, 0, 1.1066115912409111), (2, 1, 0.0), (2, 2, 0.19083752823558645), (2, 3, 0.19083752823558645), (3, 0, 1.1066115912409111), (3, 1, 0.0), (3, 2, 0.19083752823558645), (3, 3, 0.19083752823558645), (4, 0, 0.9470253678176344), (4, 1, 0.05456186131092424), (4, 2, 0.15362265709813114), (4, 3, 0.15362265709813114)]


.. GENERATED FROM PYTHON SOURCE LINES 60-62

Let's rewrite the prediction so that it is closer to the function
we need to convert into ONNX.

.. GENERATED FROM PYTHON SOURCE LINES 62-76

.. code-block:: default


    def predict(W, H, row_index, col_index):
        return np.dot(W[row_index, :], H[:, col_index])


    got = []
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            got.append((i, j, predict(W, H, i, j)))

    print(got)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [(0, 0, 1.8940507356352687), (0, 1, 0.10912372262184848), (0, 2, 0.3072453141962623), (0, 3, 0.3072453141962623), (1, 0, 1.014673742790866), (1, 1, 0.9848866016414943), (1, 2, 0.0), (1, 3, 0.0), (2, 0, 1.1066115912409111), (2, 1, 0.0), (2, 2, 0.19083752823558645), (2, 3, 0.19083752823558645), (3, 0, 1.1066115912409111), (3, 1, 0.0), (3, 2, 0.19083752823558645), (3, 3, 0.19083752823558645), (4, 0, 0.9470253678176344), (4, 1, 0.05456186131092424), (4, 2, 0.15362265709813114), (4, 3, 0.15362265709813114)]


.. GENERATED FROM PYTHON SOURCE LINES 77-87

Conversion into ONNX
++++++++++++++++++++

There is no converter implemented for
`NMF `_
because the function we plan to convert is neither a transformer
nor a predictor.
The following converter does not need to be registered:
it simply creates an ONNX graph equivalent to the function *predict*
implemented above.

.. GENERATED FROM PYTHON SOURCE LINES 87-113
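
The graph built by the converter that follows does not call a dot-product
operator: it selects column *j* of *H* and column *i* of :math:`W^T` with
``ArrayFeatureExtractor``, multiplies them element-wise with ``Mul`` and sums
the result with ``ReduceSum``. The NumPy sketch below mirrors those three
steps and checks them against ``np.dot``. It is only an illustration:
``W_demo``, ``H_demo`` and ``predict_like_graph`` are hypothetical names, the
factors are random non-negative stand-ins rather than the fitted *W* and *H*,
and the shapes assume the ONNX semantics used here (``ArrayFeatureExtractor``
indexes the last axis, ``ReduceSum`` without ``axes`` reduces every axis and
keeps dimensions).

.. code-block:: python

    import numpy as np

    # Hypothetical stand-ins for the fitted factors: 5 users, 4 products, rank 2.
    rng = np.random.default_rng(0)
    W_demo = rng.random((5, 2))
    H_demo = rng.random((2, 4))


    def predict_like_graph(W, H, row_index, col_index):
        # ArrayFeatureExtractor picks indices along the last axis:
        # H[:, [j]] and W.T[:, [i]] are both column vectors of shape (k, 1).
        col = H[:, [col_index]]
        row = W.T[:, [row_index]]
        # Mul: element-wise product, still shape (k, 1).
        prod = col * row
        # ReduceSum without axes (keepdims=1) sums every axis -> shape (1, 1).
        return prod.sum(keepdims=True)


    # The decomposition matches np.dot for every (user, product) pair.
    for i in range(W_demo.shape[0]):
        for j in range(H_demo.shape[1]):
            expected = np.dot(W_demo[i, :], H_demo[:, j])
            assert np.allclose(predict_like_graph(W_demo, H_demo, i, j), expected)

    print(predict_like_graph(W_demo, H_demo, 0, 0).shape)  # prints (1, 1)

Because ``ReduceSum`` reduces every axis here, feeding several indices at
once would add all the selected products into a single value; like the
prediction loop further down, the graph is meant to be queried one *(i, j)*
pair at a time.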
.. code-block:: default



    def nmf_to_onnx(W, H, op_version=12):
        """
        Converts an NMF described by matrices *W*, *H*
        (*WH* approximates the training data *M*)
        into a function which takes two indices *(i, j)*
        and returns the prediction for them.
        It assumes these indices apply to the training data.
        """
        col = OnnxArrayFeatureExtractor(H, "col")
        row = OnnxArrayFeatureExtractor(W.T, "row")
        dot = OnnxMul(col, row, op_version=op_version)
        res = OnnxReduceSum(dot, output_names="rec", op_version=op_version)
        indices_type = np.array([0], dtype=np.int64)
        onx = res.to_onnx(
            inputs={"col": indices_type, "row": indices_type},
            outputs=[("rec", FloatTensorType((None, 1)))],
            target_opset=op_version,
        )
        return onx


    model_onnx = nmf_to_onnx(W.astype(np.float32), H.astype(np.float32))
    print(model_onnx)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ir_version: 7
    opset_import {
      domain: ""
      version: 12
    }
    opset_import {
      domain: "ai.onnx.ml"
      version: 1
    }
    producer_name: "skl2onnx"
    producer_version: "1.16.0"
    domain: "ai.onnx"
    model_version: 0
    graph {
      node {
        input: "Ar_ArrayFeatureExtractorcst"
        input: "col"
        output: "Ar_Z0"
        name: "Ar_ArrayFeatureExtractor"
        op_type: "ArrayFeatureExtractor"
        domain: "ai.onnx.ml"
      }
      node {
        input: "Ar_ArrayFeatureExtractorcst1"
        input: "row"
        output: "Ar_Z02"
        name: "Ar_ArrayFeatureExtractor1"
        op_type: "ArrayFeatureExtractor"
        domain: "ai.onnx.ml"
      }
      node {
        input: "Ar_Z0"
        input: "Ar_Z02"
        output: "Mu_C0"
        name: "Mu_Mul"
        op_type: "Mul"
        domain: ""
      }
      node {
        input: "Mu_C0"
        output: "rec"
        name: "Re_ReduceSum"
        op_type: "ReduceSum"
        domain: ""
      }
      name: "OnnxReduceSum"
      initializer {
        dims: 2
        dims: 4
        data_type: 1
        float_data: 1.98630548
        float_data: 0
        float_data: 0.342542619
        float_data: 0.342542619
        float_data: 0.900879145
        float_data: 0.874432623
        float_data: 0
        float_data: 0
        name: "Ar_ArrayFeatureExtractorcst"
      }
      initializer {
        dims: 2
        dims: 5
        data_type: 1
        float_data: 0.896955
        float_data: 0
        float_data: 0.557120502
        float_data: 0.557120502
        float_data: 0.448477507
        float_data: 0.124793746
        float_data: 1.126315
        float_data: 0
        float_data: 0
        float_data: 0.0623968728
        name: "Ar_ArrayFeatureExtractorcst1"
      }
      input {
        name: "col"
        type {
          tensor_type {
            elem_type: 7
            shape {
              dim {
              }
            }
          }
        }
      }
      input {
        name: "row"
        type {
          tensor_type {
            elem_type: 7
            shape {
              dim {
              }
            }
          }
        }
      }
      output {
        name: "rec"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
              }
              dim {
                dim_value: 1
              }
            }
          }
        }
      }
    }


.. GENERATED FROM PYTHON SOURCE LINES 114-115

Let's compute the predictions with it.

.. GENERATED FROM PYTHON SOURCE LINES 115-137

.. code-block:: default


    sess = InferenceSession(
        model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
    )


    def predict_onnx(sess, row_indices, col_indices):
        res = sess.run(None, {"col": col_indices, "row": row_indices})
        return res


    onnx_preds = []
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            row_indices = np.array([i], dtype=np.int64)
            col_indices = np.array([j], dtype=np.int64)
            pred = predict_onnx(sess, row_indices, col_indices)[0]
            onnx_preds.append((i, j, pred[0, 0]))

    print(onnx_preds)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [(0, 0, 1.8940508), (0, 1, 0.10912372), (0, 2, 0.3072453), (0, 3, 0.3072453), (1, 0, 1.0146737), (1, 1, 0.9848866), (1, 2, 0.0), (1, 3, 0.0), (2, 0, 1.1066115), (2, 1, 0.0), (2, 2, 0.19083752), (2, 3, 0.19083752), (3, 0, 1.1066115), (3, 1, 0.0), (3, 2, 0.19083752), (3, 3, 0.19083752), (4, 0, 0.9470254), (4, 1, 0.05456186), (4, 2, 0.15362266), (4, 3, 0.15362266)]


.. GENERATED FROM PYTHON SOURCE LINES 138-139

The ONNX graph looks like the following.

.. GENERATED FROM PYTHON SOURCE LINES 139-151

.. code-block:: default

    pydot_graph = GetPydotGraph(
        model_onnx.graph,
        name=model_onnx.graph.name,
        rankdir="TB",
        node_producer=GetOpNodeProducer("docstring"),
    )
    pydot_graph.write_dot("graph_nmf.dot")

    os.system("dot -O -Tpng graph_nmf.dot")

    image = plt.imread("graph_nmf.dot.png")
    plt.imshow(image)
    plt.axis("off")

.. image-sg:: /auto_examples/images/sphx_glr_plot_nmf_001.png
   :alt: plot nmf
   :srcset: /auto_examples/images/sphx_glr_plot_nmf_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (-0.5, 1654.5, 846.5, -0.5)


.. GENERATED FROM PYTHON SOURCE LINES 152-153

**Versions used for this example**

.. GENERATED FROM PYTHON SOURCE LINES 153-159

.. code-block:: default

    print("numpy:", np.__version__)
    print("scikit-learn:", sklearn.__version__)
    print("onnx: ", onnx.__version__)
    print("onnxruntime: ", onnxruntime.__version__)
    print("skl2onnx: ", skl2onnx.__version__)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    numpy: 1.23.5
    scikit-learn: 1.4.dev0
    onnx:  1.15.0
    onnxruntime:  1.16.0+cu118
    skl2onnx:  1.16.0


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.298 seconds)


.. _sphx_glr_download_auto_examples_plot_nmf.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_nmf.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_nmf.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_