Commit a31e59e

Add ORT inference (#113)
* added gpu extras and added > transformers for token-classification pipeline issue
* added numpy and huggingface hub to required packages
* added modeling_* classes
* adding tests and pipelines
* remove vs code folder
* added test model and adjusted gitignore
* add readme for tests
* working tests
* added some documentation
* will ci run?
* added real model checkpoints
* test ci
* fix styling
* fix some documentation
* more doc fixes
* added some feedback and wording from michael and lewis
* renamed model class to ORTModelForXX
* moved from_transformers to from_pretrained
* applied ellas feedback
* make style
* first version of ORTModelForCausalLM without past-keys
* added first draft of new .optimize method
* added better quantize method
* fix import
* remove optimize and quantize
* added lewis feedback
* added style for test
* added >>> to code snippets
* style
* added condition for staging tests
* feedback morgan & michael
* added action
* forgot to install pytest
* forgot sentence piece
* made sure we won't have import conflicts
* make style happy
1 parent 7417202 commit a31e59e

20 files changed: +2036 −10 lines changed
+32
@@ -0,0 +1,32 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Onnxruntime Models (Inference) / Python - Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.8, 3.9]
        os: [ubuntu-20.04 ] #, windows-2019, macos-10.15]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install .[tests,onnxruntime]
      - name: Test with pytest
        shell: bash
        run: |
          pytest tests/onnxruntime/test_modeling_ort.py

.gitignore

+4
@@ -131,3 +131,7 @@ dmypy.json

# Models
*.onnx
# include small test model for tests
!tests/assets/onnx/model.onnx

.vscode

docs/source/_toctree.yml

+4
@@ -3,8 +3,12 @@
    title: 🤗 Optimum
  - local: quickstart
    title: Quickstart
  - local: pipelines
    title: Pipelines for inference
  title: Get started
- sections:
  - local: onnxruntime/modeling_ort
    title: Inference
  - local: onnxruntime/configuration
    title: Configuration
  - local: onnxruntime/optimization
docs/source/onnxruntime/modeling_ort.mdx

+103
@@ -0,0 +1,103 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Optimum Inference with ONNX Runtime

Optimum is a utility package for building and running inference with accelerated runtimes like ONNX Runtime.
Optimum can be used to load optimized models from the [Hugging Face Hub](https://hf.co/models) and create pipelines
to run accelerated inference without rewriting your APIs.

## Switching from Transformers to Optimum Inference

The Optimum Inference models are API compatible with Hugging Face Transformers models. This means you can just replace your `AutoModelForXxx` class with the corresponding `ORTModelForXxx` class in `optimum`. For example, this is how you can use a question answering model in `optimum`:

```diff
from transformers import AutoTokenizer, pipeline
-from transformers import AutoModelForQuestionAnswering
+from optimum.onnxruntime import ORTModelForQuestionAnswering

-model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2") # pytorch checkpoint
+model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2") # onnx checkpoint
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")

onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer)

question = "What's my name?"
context = "My name is Philipp and I live in Nuremberg."
pred = onnx_qa(question, context)
```

Optimum Inference also includes methods to convert vanilla Transformers models to optimized ones. Simply pass `from_transformers=True` to the `from_pretrained()` method, and your model will be loaded and converted to ONNX on-the-fly:

```python
>>> from transformers import AutoTokenizer, pipeline
>>> from optimum.onnxruntime import ORTModelForSequenceClassification

# load model from hub and convert
>>> model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# create pipeline
>>> onnx_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

>>> result = onnx_classifier(text="This is a great model")
[{'label': 'POSITIVE', 'score': 0.9998838901519775}]
```

You can find a complete walkthrough of Optimum Inference for ONNX Runtime in this [notebook](xx).

### Working with the [Hugging Face Model Hub](https://hf.co/models)

The Optimum model classes like [`~ORTModelForSequenceClassification`] are integrated with the [Hugging Face Model Hub](https://hf.co/models), which means you can not only
load models from the Hub, but also push your models to the Hub with the `push_to_hub()` method. Below is an example that downloads a vanilla Transformers model
from the Hub, converts it to an Optimum ONNX Runtime model, and pushes it back to a new repository.

<!-- TODO: Add Quantizer into example when UX improved -->
```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForSequenceClassification

# load model from hub and convert
>>> model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# save converted model
>>> model.save_pretrained("a_local_path_for_convert_onnx_model")
>>> tokenizer.save_pretrained("a_local_path_for_convert_onnx_model")

# push the converted onnx model to the HF Hub
>>> model.push_to_hub("a_local_path_for_convert_onnx_model",
                      repository_id="my-onnx-repo",
                      use_auth_token=True)
```
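
Once the repository exists on the Hub, the converted model can be loaded back with the same class. Below is a minimal sketch of that round trip; the repository id is a hypothetical placeholder for whatever `push_to_hub()` created under your namespace:

```python
>>> from optimum.onnxruntime import ORTModelForSequenceClassification

# "your-username/my-onnx-repo" is a hypothetical repository id, replace it with your own
>>> model = ORTModelForSequenceClassification.from_pretrained("your-username/my-onnx-repo")
```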

## ORTModel

[[autodoc]] onnxruntime.modeling_ort.ORTModel

## ORTModelForFeatureExtraction

[[autodoc]] onnxruntime.modeling_ort.ORTModelForFeatureExtraction

## ORTModelForQuestionAnswering

[[autodoc]] onnxruntime.modeling_ort.ORTModelForQuestionAnswering

## ORTModelForSequenceClassification

[[autodoc]] onnxruntime.modeling_ort.ORTModelForSequenceClassification

## ORTModelForTokenClassification

[[autodoc]] onnxruntime.modeling_ort.ORTModelForTokenClassification
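
Like the other classes above, [`~ORTModelForTokenClassification`] can be dropped into a Transformers pipeline. The sketch below is only illustrative and assumes a public NER checkpoint such as `dslim/bert-base-NER`; the model is converted to ONNX on the fly via `from_transformers=True`:

```python
>>> from transformers import AutoTokenizer, pipeline
>>> from optimum.onnxruntime import ORTModelForTokenClassification

# the checkpoint below is an illustrative assumption, any token-classification model should work
>>> model = ORTModelForTokenClassification.from_pretrained("dslim/bert-base-NER", from_transformers=True)
>>> tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

>>> onnx_ner = pipeline("token-classification", model=model, tokenizer=tokenizer)
>>> pred = onnx_ner("My name is Philipp and I live in Nuremberg.")
```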

docs/source/pipelines.mdx

+218
@@ -0,0 +1,218 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Optimum pipelines for inference

The [`pipeline`] makes it simple to use models from the [Model Hub](https://huggingface.co/models) for accelerated inference on a variety of tasks such as text classification.
Even if you don't have experience with a specific modality or aren't familiar with the code powering the models, you can still use them with the [`pipeline`]! This tutorial will teach you how to run accelerated inference with an Optimum [`pipeline`].

<Tip>

You can also use the `pipeline()` function from Transformers and provide your `OptimumModel`.

</Tip>

Currently supported tasks are:

**ONNX Runtime**

* `feature-extraction`
* `text-classification`
* `token-classification`
* `question-answering`
* `zero-shot-classification`
* `text-generation`

## Optimum pipeline usage

While each task has an associated pipeline class, it is simpler to use the general [`pipeline`] abstraction, which contains all the task-specific pipelines.
The [`pipeline`] automatically loads a default model and tokenizer capable of inference for your task.

1. Start by creating a [`pipeline`] and specifying an inference task:

```python
>>> from optimum import pipeline

>>> classifier = pipeline(task="text-classification", accelerator="ort")
```

2. Pass your input text to the [`pipeline`]:

```python
>>> classifier("I like you. I love you.")
[{'label': 'POSITIVE', 'score': 0.9998838901519775}]
```

_Note: The default models used in the [`pipeline`] are not optimized or quantized, so there won't be a performance improvement compared to their PyTorch counterparts._
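
The same pattern applies to the other tasks in the list above. For instance, here is a minimal sketch for `feature-extraction`; the checkpoint name is only an illustrative assumption and the model is converted to ONNX on the fly:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForFeatureExtraction
>>> from optimum import pipeline

# "distilbert-base-uncased" is an illustrative choice of checkpoint
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> model = ORTModelForFeatureExtraction.from_pretrained("distilbert-base-uncased", from_transformers=True)

>>> onnx_extractor = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
>>> embeddings = onnx_extractor("My name is Philipp and I live in Nuremberg.")
```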

### Using vanilla Transformers model and converting to ONNX

The [`pipeline`] accepts any supported model from the [Model Hub](https://huggingface.co/models).
There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
Once you've picked an appropriate model, load it with the `from_pretrained("{model_id}", from_transformers=True)` method of the corresponding `ORTModelFor*`
class and the [`AutoTokenizer`] class. For example, here's how you can load the [`ORTModelForQuestionAnswering`] class for question answering:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
>>> from optimum import pipeline

>>> tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
>>> # loading the pytorch checkpoint and converting to ORT format by providing the from_transformers=True parameter
>>> model = ORTModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2", from_transformers=True)

>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
>>> question = "What's my name?"
>>> context = "My name is Philipp and I live in Nuremberg."

>>> pred = onnx_qa(question=question, context=context)
```

### Using Optimum models

The [`pipeline`] is tightly integrated with the [Model Hub](https://huggingface.co/models) and can load optimized models directly, e.g. those created with ONNX Runtime.
There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
Once you've picked an appropriate model, load it with the `from_pretrained()` method of the corresponding `ORTModelFor*`
class and the [`AutoTokenizer`] class. For example, here's how you can load an optimized model for question answering:

```python
>>> from transformers import AutoTokenizer
>>> from optimum.onnxruntime import ORTModelForQuestionAnswering
>>> from optimum import pipeline

>>> tokenizer = AutoTokenizer.from_pretrained("optimum/roberta-base-squad2")
>>> # loading already converted and optimized ORT checkpoint for inference
>>> model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")

>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
>>> question = "What's my name?"
>>> context = "My name is Philipp and I live in Nuremberg."

>>> pred = onnx_qa(question=question, context=context)
```

### Optimizing and Quantizing in Pipelines

The [`pipeline`] can not only run inference on vanilla ONNX Runtime checkpoints; you can also use checkpoints optimized with the `ORTQuantizer` and the `ORTOptimizer`.
Below you can find two examples of how you could use the [`~ORTQuantizer`] and the [`~ORTOptimizer`] to quantize/optimize your model and use it for inference afterwards.

### Quantizing with [`~ORTQuantizer`]

```python
>>> from pathlib import Path
>>> from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
>>> from optimum.onnxruntime.configuration import AutoQuantizationConfig
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

# define model_id and load tokenizer
>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> save_path = Path("optimum_model")
>>> save_path.mkdir(exist_ok=True)

# use ORTQuantizer to export the model and define quantization configuration
>>> quantizer = ORTQuantizer.from_pretrained(model_id, feature="sequence-classification")
>>> qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=True)

# apply the quantization configuration to the model
>>> quantizer.export(
        onnx_model_path=save_path / "model.onnx",
        onnx_quantized_model_output_path=save_path / "model-quantized.onnx",
        quantization_config=qconfig,
    )
>>> quantizer.model.config.save_pretrained(save_path) # saves config.json

# load quantized model from local path or repository
>>> model = ORTModelForSequenceClassification.from_pretrained(save_path, file_name="model-quantized.onnx")

# create transformers pipeline
>>> onnx_clx = pipeline("text-classification", model=model, tokenizer=tokenizer)
>>> text = "I like the new ORT pipeline"
>>> pred = onnx_clx(text)
>>> print(pred)

# save model & push model to the hub
>>> tokenizer.save_pretrained("new_path_for_directory")
>>> model.save_pretrained("new_path_for_directory")
>>> model.push_to_hub("new_path_for_directory",
                      repository_id="my-onnx-repo",
                      use_auth_token=True)
```

### Optimizing with [`~ORTOptimizer`]

```python
>>> from pathlib import Path
>>> from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
>>> from optimum.onnxruntime.configuration import OptimizationConfig
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

# define model_id and load tokenizer
>>> model_id = "distilbert-base-uncased-finetuned-sst-2-english"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> save_path = Path("optimum_model")
>>> save_path.mkdir(exist_ok=True)

# use ORTOptimizer to export the model and define optimization configuration
>>> optimizer = ORTOptimizer.from_pretrained(model_id, feature="sequence-classification")
>>> optimization_config = OptimizationConfig(optimization_level=2)

# apply the optimization configuration to the model
>>> optimizer.export(
        onnx_model_path=save_path / "model.onnx",
        onnx_optimized_model_output_path=save_path / "model-optimized.onnx",
        optimization_config=optimization_config,
    )
>>> optimizer.model.config.save_pretrained(save_path) # saves config.json

# load optimized model from local path or repository
>>> model = ORTModelForSequenceClassification.from_pretrained(save_path, file_name="model-optimized.onnx")

# create transformers pipeline
>>> onnx_clx = pipeline("text-classification", model=model, tokenizer=tokenizer)
>>> text = "I like the new ORT pipeline"
>>> pred = onnx_clx(text)
>>> print(pred)

# save model & push model to the hub
>>> tokenizer.save_pretrained("new_path_for_directory")
>>> model.save_pretrained("new_path_for_directory")
>>> model.push_to_hub("new_path_for_directory",
                      repository_id="my-onnx-repo",
                      use_auth_token=True)
```

## Transformers pipeline usage

The [`pipeline`] is just a light wrapper around the `transformers.pipeline` function to enable checks for supported tasks and additional features,
like quantization and optimization. That being said, you can use the `transformers.pipeline` function and just replace your `AutoModelFor*` class with the corresponding optimum
`ORTModelFor*` class.

```diff
from transformers import AutoTokenizer, pipeline
-from transformers import AutoModelForQuestionAnswering
+from optimum.onnxruntime import ORTModelForQuestionAnswering

-model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
+model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")

onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer)

question = "What's my name?"
context = "My name is Philipp and I live in Nuremberg."
pred = onnx_qa(question, context)
```
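
The remaining tasks in the list at the top of this page follow the same substitution pattern. As a final hedged sketch, `text-generation` through the generic Optimum [`pipeline`] might look like the following, assuming the `ort` accelerator can resolve a default model for this task in this version; if it cannot, pass an ORT model explicitly as in the examples above:

```python
>>> from optimum import pipeline

# assumes the "ort" accelerator provides a default text-generation model;
# if not, load an ORT model explicitly and pass it via the model argument
>>> generator = pipeline(task="text-generation", accelerator="ort")
>>> generated = generator("My name is Philipp and I")
```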
