.. _sphx_glr_auto_examples_plot_tfidfvectorizer.py:


.. _l-example-tfidfvectorizer:

TfIdfVectorizer with ONNX
=========================

This example is inspired by the scikit-learn example
`Column Transformer with Heterogeneous Data Sources `_,
which builds a pipeline to classify text.

.. contents::
    :local:

Train a pipeline with TfidfVectorizer
+++++++++++++++++++++++++++++++++++++

It replicates the pipeline from the *scikit-learn* documentation but
reduces it to the parts ONNX supports without a custom converter.
Let's get the data.

.. code-block:: default


    import matplotlib.pyplot as plt
    import os
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    import onnxruntime as rt
    from skl2onnx.common.data_types import StringTensorType
    from skl2onnx import convert_sklearn
    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.datasets import fetch_20newsgroups
    try:
        from sklearn.datasets._twenty_newsgroups import (
            strip_newsgroup_footer, strip_newsgroup_quoting)
    except ImportError:
        # scikit-learn < 0.24
        from sklearn.datasets.twenty_newsgroups import (
            strip_newsgroup_footer, strip_newsgroup_quoting)
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.pipeline import Pipeline
    from sklearn.compose import ColumnTransformer
    from sklearn.metrics import classification_report
    from sklearn.linear_model import LogisticRegression

    # limit the list of categories to make running this example faster.
    categories = ['alt.atheism', 'talk.religion.misc']
    train = fetch_20newsgroups(random_state=1,
                               subset='train',
                               categories=categories,
                               )
    test = fetch_20newsgroups(random_state=1,
                              subset='test',
                              categories=categories,
                              )

The first transform extracts two fields from the data. We take it out of
the pipeline and assume the data is made of two text columns.

.. code-block:: default


    class SubjectBodyExtractor(BaseEstimator, TransformerMixin):
        """Extract the subject & body from a usenet post in a single pass.
        Takes a sequence of strings and produces a two-column object array:
        column 0 holds the subject and column 1 holds the body.
        """

        def fit(self, x, y=None):
            return self

        def transform(self, posts):
            # construct object dtype array with two columns
            # first column = 'subject' and second column = 'body'
            features = np.empty(shape=(len(posts), 2), dtype=object)
            for i, text in enumerate(posts):
                # headers and body are separated by the first blank line
                headers, _, bod = text.partition('\n\n')
                bod = strip_newsgroup_footer(bod)
                bod = strip_newsgroup_quoting(bod)
                features[i, 1] = bod

                prefix = 'Subject:'
                sub = ''
                for line in headers.split('\n'):
                    if line.startswith(prefix):
                        sub = line[len(prefix):]
                        break
                features[i, 0] = sub

            return features


    train_data = SubjectBodyExtractor().fit_transform(train.data)
    test_data = SubjectBodyExtractor().fit_transform(test.data)

The pipeline is almost the same, except that the custom features are removed.

.. code-block:: default


    pipeline = Pipeline([
        ('union', ColumnTransformer(
            [
                ('subject', TfidfVectorizer(min_df=50), 0),

                ('body_bow', Pipeline([
                    ('tfidf', TfidfVectorizer()),
                    ('best', TruncatedSVD(n_components=50)),
                ]), 1),

                # Removed from the original example as
                # it requires a custom converter.
                # ('body_stats', Pipeline([
                #   ('stats', TextStats()),  # returns a list of dicts
                #   ('vect', DictVectorizer()),  # list of dicts -> feature matrix
                # ]), 1),
            ],

            transformer_weights={
                'subject': 0.8,
                'body_bow': 0.5,
                # 'body_stats': 1.0,
            }
        )),

        # Use a LogisticRegression classifier on the combined features
        # instead of LinearSVC (not fully supported in onnxruntime).
        ('logreg', LogisticRegression()),
    ])

    pipeline.fit(train_data, train.target)
    # classification_report expects (y_true, y_pred); the arguments are
    # swapped here, so precision and recall are swapped in the report below.
    print(classification_report(pipeline.predict(test_data), test.target))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

                  precision    recall  f1-score   support

               0       0.69      0.78      0.73       285
               1       0.75      0.66      0.70       285

        accuracy                           0.72       570
       macro avg       0.72      0.72      0.71       570
    weighted avg       0.72      0.72      0.71       570


ONNX conversion
+++++++++++++++

It is difficult to replicate the exact same tokenizer behaviour if the
tokenizer comes from spaCy, gensim or nltk. The default tokenizer used by
*scikit-learn* relies on regular expressions, and its conversion is still
being implemented. The current implementation only considers a list of
separators, defined in the variable *seps*.

.. code-block:: default


    seps = {
        TfidfVectorizer: {
            "separators": [
                ' ', '.', '\\?', ',', ';', ':', '!',
                '\\(', '\\)', '\n', '"', "'",
                "-", "\\[", "\\]", "@"
            ]
        }
    }
    model_onnx = convert_sklearn(
        pipeline, "tfidf",
        initial_types=[("input", StringTensorType([None, 2]))],
        options=seps, target_opset=12)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /home/dupre/github_xadupre/sklearn-onnx/skl2onnx/common/_container.py:591: UserWarning: Unable to find operator 'Tokenizer' in domain 'com.microsoft' in ONNX, op_version is forced to 1.
      op_type, domain))

And save it.

.. code-block:: default


    with open("pipeline_tfidf.onnx", "wb") as f:
        f.write(model_onnx.SerializeToString())

Predictions with onnxruntime.

.. code-block:: default


    # Listing the providers explicitly avoids an error with recent
    # onnxruntime builds that ship several execution providers.
    sess = rt.InferenceSession("pipeline_tfidf.onnx",
                               providers=["CPUExecutionProvider"])
    print('---', train_data[0])
    inputs = {'input': train_data[:1]}
    pred_onx = sess.run(None, inputs)
    print("predict", pred_onx[0])
    print("predict_proba", pred_onx[1])

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    --- [" Re: Jews can't hide from keith@cco."
     'Deletions...\n\nSo, you consider the german poster\'s remark anti-semitic? Perhaps you\nimply that anyone in Germany who doesn\'t agree with israely policy in a\nnazi? Pray tell, how does it even qualify as "casual anti-semitism"? \nIf the term doesn\'t apply, why then bring it up?\n\nYour own bigotry is shining through. \n-- ']
    predict [1]
    predict_proba [{0: 0.4389554560184479, 1: 0.5610445141792297}]

With *scikit-learn*:

.. code-block:: default


    print(pipeline.predict(train_data[:1]))
    print(pipeline.predict_proba(train_data[:1]))

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [0]
    [[0.71418415 0.28581585]]

There are discrepancies for this model because the tokenization is not
exactly the same. On this first training sample, the two runtimes even
predict different classes (1 with *onnxruntime*, 0 with *scikit-learn*).
This is a work in progress.

Display the ONNX graph
++++++++++++++++++++++

Finally, let's see the graph converted with *sklearn-onnx*.

.. code-block:: default


    pydot_graph = GetPydotGraph(
        model_onnx.graph, name=model_onnx.graph.name, rankdir="TB",
        node_producer=GetOpNodeProducer(
            "docstring", color="yellow", fillcolor="yellow", style="filled"))
    pydot_graph.write_dot("pipeline_tfidf.dot")

    # requires the Graphviz ``dot`` executable
    os.system('dot -O -Gdpi=300 -Tpng pipeline_tfidf.dot')

    image = plt.imread("pipeline_tfidf.dot.png")
    fig, ax = plt.subplots(figsize=(40, 20))
    ax.imshow(image)
    ax.axis('off')

.. image:: /auto_examples/images/sphx_glr_plot_tfidfvectorizer_001.png
    :alt: plot tfidfvectorizer
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (-0.5, 6614.5, 6723.5, -0.5)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 9.513 seconds)
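
The converter is told that the input is a string tensor of shape ``[None, 2]``
(``StringTensorType([None, 2])``). That is the layout ``SubjectBodyExtractor``
builds: one row per post, the subject in column 0 and the body in column 1.
The sketch below reproduces that layout with the standard library and *numpy*
only. The helpers ``split_post`` and ``to_two_columns`` are hypothetical, and
they skip the footer and quoting cleanup done by ``strip_newsgroup_footer``
and ``strip_newsgroup_quoting``.

.. code-block:: python

    import numpy as np


    def split_post(text):
        """Hypothetical helper: return (subject, body) for one usenet post.

        Mirrors SubjectBodyExtractor.transform without the footer and
        quoting cleanup applied by scikit-learn's helpers.
        """
        # headers and body are separated by the first blank line
        headers, _, body = text.partition('\n\n')
        prefix = 'Subject:'
        subject = ''
        for line in headers.split('\n'):
            if line.startswith(prefix):
                # keep everything after the prefix, including the leading space
                subject = line[len(prefix):]
                break
        return subject, body


    def to_two_columns(posts):
        """Build the (n_posts, 2) object array matching the ONNX input type."""
        features = np.empty(shape=(len(posts), 2), dtype=object)
        for i, text in enumerate(posts):
            features[i, 0], features[i, 1] = split_post(text)
        return features


    posts = [
        "From: a@example.com\nSubject: Re: first question\n\nBody one.",
        "From: b@example.com\n\nBody without subject.",
    ]
    data = to_two_columns(posts)
    print(data.shape, data.dtype)
    print(repr(data[0, 0]))

The leading space kept in the subject is also visible in the printed
``train_data[0]`` above. Such an array is what gets fed to onnxruntime as
``{'input': data}``.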
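
In the training pipeline, ``ColumnTransformer`` multiplies the output of each
transformer by its entry in ``transformer_weights`` (0.8 for ``subject``,
0.5 for ``body_bow``) and concatenates the results column-wise before
``LogisticRegression`` sees them. The following *numpy* sketch shows that
combination. The matrices are made-up dense stand-ins for the TF-IDF and SVD
outputs, and the real outputs may be sparse.

.. code-block:: python

    import numpy as np

    # Hypothetical outputs for 3 documents:
    # 'subject' -> TfidfVectorizer(min_df=50), 'body_bow' -> TruncatedSVD.
    subject_features = np.array([[0.0, 1.0],
                                 [0.6, 0.8],
                                 [1.0, 0.0]])
    body_features = np.array([[0.2, -0.1, 0.4],
                              [0.0, 0.3, 0.1],
                              [0.5, 0.5, -0.2]])

    transformer_weights = {'subject': 0.8, 'body_bow': 0.5}

    # Scale each block by its weight, then concatenate blocks column-wise.
    combined = np.hstack([
        transformer_weights['subject'] * subject_features,
        transformer_weights['body_bow'] * body_features,
    ])
    print(combined.shape)
    print(combined[1])

The weights only rescale each block, so they change how strongly the
subject features count relative to the body features in the linear model.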
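
The discrepancies described above come from tokenization. The pure-Python
sketch below contrasts splitting on the characters listed in *seps* with the
default ``token_pattern`` of *scikit-learn*'s ``TfidfVectorizer``,
``r"(?u)\b\w\w+\b"``. It is only an approximation of the converted graph. It
assumes that every separator is matched as a literal character, that text is
lowercased, and that empty tokens are dropped. It does not call onnxruntime's
``Tokenizer`` operator. The entries in *seps* escape characters such as ``?``
and ``(``, which suggests they are read as regular expressions. The sketch
lists the plain characters and escapes them itself.

.. code-block:: python

    import re

    # Plain characters corresponding to the separators in ``seps``
    # (assumption: each one is matched literally).
    SEPARATOR_CHARS = [' ', '.', '?', ',', ';', ':', '!', '(', ')', '\n',
                       '"', "'", '-', '[', ']', '@']
    SEPARATOR_PATTERN = "|".join(re.escape(c) for c in SEPARATOR_CHARS)

    # Default token_pattern of scikit-learn's TfidfVectorizer.
    SKLEARN_TOKEN_PATTERN = r"(?u)\b\w\w+\b"


    def separator_tokens(text):
        """Lowercase, split on any separator, drop empty strings."""
        return [tok for tok in re.split(SEPARATOR_PATTERN, text.lower()) if tok]


    def sklearn_default_tokens(text):
        """Lowercase, keep runs of two or more word characters."""
        return re.findall(SKLEARN_TOKEN_PATTERN, text.lower())


    for text in ["Re: Jews can't hide from keith@cco.", "e-mail and/or fax"]:
        print(separator_tokens(text))
        print(sklearn_default_tokens(text))

The separator approach keeps one-character tokens such as ``t`` (from
``can't``) and ``e`` (from ``e-mail``). It also leaves ``and/or`` in one piece
because ``/`` is not a separator. The regular expression drops the former and
splits the latter. Differences like these can change the vocabulary lookups,
and therefore the TF-IDF vectors and the predicted probabilities.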
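
To quantify the gap between both runtimes, the probabilities returned by
onnxruntime can be turned into an array and compared with ``predict_proba``.
onnxruntime returns them as a list of ``{class: probability}`` dictionaries,
as printed above. The following sketch uses the two outputs printed in this
example. The helper name ``zipmap_to_array`` is hypothetical.

.. code-block:: python

    import numpy as np


    def zipmap_to_array(proba_dicts):
        """Convert a list of {class_label: probability} dicts into a 2D array,
        with columns ordered by sorted class label."""
        labels = sorted(proba_dicts[0])
        return np.array([[d[label] for label in labels] for d in proba_dicts])


    # Values copied from the outputs printed above for train_data[:1].
    onnx_proba = zipmap_to_array([{0: 0.4389554560184479, 1: 0.5610445141792297}])
    sklearn_proba = np.array([[0.71418415, 0.28581585]])

    max_abs_diff = np.abs(onnx_proba - sklearn_proba).max()
    same_class = bool((onnx_proba.argmax(axis=1)
                       == sklearn_proba.argmax(axis=1)).all())
    print("max absolute difference:", max_abs_diff)
    print("same predicted class:", same_class)

In a real run, ``pred_onx[1]`` would be passed to ``zipmap_to_array`` and
compared with ``pipeline.predict_proba`` on the same rows.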