Commit e2d3c8b
Add AMD GPU support (#1546)
* add amd gpu tests
* add docs
* Add ORT trainer docs and dockerfile
* addressed comments
* added pytorch installation step
* update test
* update

Co-authored-by: Mohit Sharma <mohit@huggingface.co>
1 parent: 521d069

File tree: 13 files changed, +832 -66 lines

docs/source/_toctree.yml

+2

@@ -30,6 +30,8 @@
       title: How to accelerate training
     - local: onnxruntime/usage_guides/gpu
       title: Accelerated inference on NVIDIA GPUs
+    - local: onnxruntime/usage_guides/amdgpu
+      title: Accelerated inference on AMD GPUs
     title: How-to guides
     isExpanded: false
   - sections:
docs/source/onnxruntime/usage_guides/amdgpu.mdx (new file, +124 lines)

# Accelerated inference on AMD GPUs supported by ROCm

By default, ONNX Runtime runs inference on CPU devices. However, it is possible to place supported operations on an AMD Instinct GPU while leaving any unsupported ones on CPU. In most cases, this allows costly operations to be placed on the GPU, significantly accelerating inference.
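As an illustration of this placement, ONNX Runtime exposes the execution providers it was built with, in priority order; operations the ROCm provider does not implement fall back to `CPUExecutionProvider`. A minimal check, assuming a ROCm-enabled `onnxruntime` build:

```python
>>> import onnxruntime

>>> # Providers are listed in priority order; unsupported operations fall back to CPU.
>>> print(onnxruntime.get_available_providers())  # doctest: +SKIP
['ROCMExecutionProvider', 'CPUExecutionProvider']
```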
Our testing was done on AMD Instinct GPUs; for specific GPU compatibility, please refer to the official list of supported GPUs available [here](https://rocm.docs.amd.com/en/latest/release/gpu_os_support.html).

This guide will show you how to run inference with the `ROCMExecutionProvider` execution provider that ONNX Runtime supports for AMD GPUs.
## Installation

The following setup installs ONNX Runtime with the ROCm Execution Provider, using ROCm 5.7.

#### 1. ROCm Installation

To install ROCm 5.7, please follow the [ROCm installation guide](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html).

#### 2. PyTorch Installation with ROCm Support

Optimum ONNX Runtime integration relies on some functionalities of Transformers that require PyTorch. For now, we recommend using PyTorch compiled against ROCm 5.7, which can be installed following the [PyTorch installation guide](https://pytorch.org/get-started/locally/):

```bash
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
```
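To quickly verify that the installed PyTorch is a ROCm build, a small sketch (`torch.version.hip` is only populated on ROCm builds):

```python
>>> import torch

>>> # torch.version.hip is set on ROCm builds (e.g. '5.7...'), None on CUDA builds.
>>> print(torch.version.hip)  # doctest: +SKIP
>>> # PyTorch exposes AMD GPUs through the cuda device type:
>>> print(torch.cuda.is_available())  # doctest: +SKIP
```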
<Tip>
For Docker installation, the following base image is recommended: `rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1`
</Tip>

#### 3. ONNX Runtime installation with ROCm Execution Provider
```bash
# pre-requisites
pip install -U pip
pip install cmake onnx
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

# Install ONNX Runtime from source
git clone --recursive https://github.com/ROCmSoftwarePlatform/onnxruntime.git
cd onnxruntime
git checkout rocm5.7_internal_testing_eigen-3.4.zip_hash

./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm
pip install build/Linux/Release/dist/*
```
<Tip>
To avoid conflicts between `onnxruntime` and `onnxruntime-rocm`, make sure the package `onnxruntime` is not installed by running `pip uninstall onnxruntime` prior to installing `onnxruntime-rocm`.
</Tip>
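A quick way to confirm which of the two packages is present, using only the standard library (a sketch; the distribution name `onnxruntime-rocm` follows the tip above):

```python
>>> import importlib.metadata

>>> for dist in ("onnxruntime", "onnxruntime-rocm"):
...     try:
...         print(dist, importlib.metadata.version(dist))
...     except importlib.metadata.PackageNotFoundError:
...         print(dist, "not installed")
```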
### Checking that the ROCm installation is successful

Before going further, run the following sample code to check whether the install was successful:

```python
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
>>> from transformers import AutoTokenizer

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
...     "philschmid/tiny-bert-sst2-distilled",
...     export=True,
...     provider="ROCMExecutionProvider",
... )

>>> tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
>>> inputs = tokenizer("expectations were low, actual enjoyment was high", return_tensors="pt", padding=True)

>>> outputs = ort_model(**inputs)
>>> assert ort_model.providers == ["ROCMExecutionProvider", "CPUExecutionProvider"]
```
If this code runs smoothly, congratulations, the installation is successful! If you encounter the following error or a similar one,

```
ValueError: Asked to use ROCMExecutionProvider as an ONNX Runtime execution provider, but the available execution providers are ['CPUExecutionProvider'].
```

then something is wrong with the ROCm or ONNX Runtime installation.
### Use the ROCm Execution Provider with ORT models

For ORT models, usage is straightforward: simply specify the `provider` argument in the `ORTModel.from_pretrained()` method. Here's an example:

```python
>>> from optimum.onnxruntime import ORTModelForSequenceClassification

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
...     "distilbert-base-uncased-finetuned-sst-2-english",
...     export=True,
...     provider="ROCMExecutionProvider",
... )
```

The model can then be used with the common 🤗 Transformers API for inference and evaluation, such as [pipelines](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/pipelines).
When using a Transformers pipeline, note that the `device` argument should be set to perform pre- and post-processing on the GPU (on ROCm builds, PyTorch exposes AMD GPUs through the `cuda` device type), following the example below:

```python
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

>>> pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
>>> result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
>>> print(result)  # doctest: +IGNORE_RESULT
# printing: [{'label': 'POSITIVE', 'score': 0.9997727274894714}]
```
Additionally, you can pass the session option `log_severity_level = 0` (verbose) to check whether all nodes are indeed placed on the ROCm execution provider:

```python
>>> import onnxruntime

>>> session_options = onnxruntime.SessionOptions()
>>> session_options.log_severity_level = 0

>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
...     "distilbert-base-uncased-finetuned-sst-2-english",
...     export=True,
...     provider="ROCMExecutionProvider",
...     session_options=session_options
... )
```

### Observed time gains

Coming soon!

docs/source/onnxruntime/usage_guides/trainer.mdx

+28

@@ -56,6 +56,8 @@ To use `ORTTrainer` or `ORTSeq2SeqTrainer`, you need to install ONNX Runtime Tra
 To set up the environment, we __strongly recommend__ you install the dependencies with Docker to ensure that the versions are correct and well
 configured. You can find dockerfiles with various combinations [here](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training/docker).

+#### Setup for NVIDIA GPU
+
 Here below we take the installation of `onnxruntime-training 1.14.0` as an example:

 * If you want to install `onnxruntime-training 1.14.0` via [Dockerfile](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort1.14.0-cu116):
@@ -80,6 +82,32 @@ And run post-installation configuration:
 python -m torch_ort.configure
 ```

+#### Setup for AMD GPU
+
+Below we take the installation of the `onnxruntime-training` nightly as an example:
+
+* If you want to install `onnxruntime-training` via [Dockerfile](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort-nightly-rocm57):
+
+```bash
+docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
+```
+
+* If you instead want to install the dependencies in a local Python environment, you can pip install them once you have [ROCm 5.7](https://rocmdocs.amd.com/en/latest/deploy/linux/quick_start.html) installed:
+
+```bash
+pip install onnx ninja
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
+pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
+pip install torch-ort
+pip install --upgrade protobuf==3.20.2
+```
+
+And run post-installation configuration:
+
+```bash
+python -m torch_ort.configure
+```
+
 ### Install Optimum

 You can install Optimum via pypi:
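With either setup in place, training goes through the `ORTTrainer` API, which mirrors the 🤗 Transformers `Trainer`. A minimal, self-contained sketch (the model checkpoint and toy dataset are illustrative placeholders, not part of this commit):

```python
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

model_id = "distilbert-base-uncased"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

# Tiny toy dataset so the sketch runs end to end.
raw = Dataset.from_dict({"text": ["great movie", "terrible movie"], "label": [1, 0]})
train_dataset = raw.map(
    lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=32)
)

args = ORTTrainingArguments(output_dir="ort_out", per_device_train_batch_size=2, num_train_epochs=1)
trainer = ORTTrainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()
```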
examples/onnxruntime/training/docker/Dockerfile-ort-nightly-rocm57 (new file, +43 lines)

# Use rocm image
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
CMD rocm-smi

# Ignore interactive questions during `docker build`
ENV DEBIAN_FRONTEND noninteractive

# Versions
# available options 3.10
ARG PYTHON_VERSION=3.10

# Bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]

# Install and update tools to minimize security vulnerabilities
RUN apt-get update
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake \
    bzip2 ca-certificates libglib2.0-0 libxext6 libsm6 libxrender1 mercurial subversion libopenmpi-dev ffmpeg && \
    apt-get clean
RUN apt-get autoremove -y

ARG PYTHON_EXE=/opt/conda/envs/py_$PYTHON_VERSION/bin/python

# (Optional) Install test dependencies
RUN $PYTHON_EXE -m pip install -U pip
RUN $PYTHON_EXE -m pip install git+https://github.com/huggingface/transformers
RUN $PYTHON_EXE -m pip install datasets accelerate evaluate coloredlogs absl-py rouge_score seqeval scipy sacrebleu nltk scikit-learn parameterized sentencepiece --no-cache-dir
RUN $PYTHON_EXE -m pip install deepspeed --no-cache-dir
RUN conda install -y mpi4py

# ONNX and build dependencies
RUN $PYTHON_EXE -m pip install onnx ninja

# ORT Module
RUN $PYTHON_EXE -m pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
RUN $PYTHON_EXE -m pip install torch-ort
RUN $PYTHON_EXE -m pip install --upgrade protobuf==3.20.2
RUN $PYTHON_EXE -m torch_ort.configure

WORKDIR .

CMD ["/bin/bash"]

optimum/onnxruntime/modeling_ort.py

+8 -4

@@ -308,16 +308,20 @@ def to(self, device: Union[torch.device, str, int]):
         if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
             return self

-        if device.type == "cuda" and self._use_io_binding is False:
+        self.device = device
+        provider = get_provider_for_device(self.device)
+        validate_provider_availability(provider)  # raise error if the provider is not available
+
+        # IOBinding is only supported for CPU and CUDA Execution Providers.
+        if device.type == "cuda" and self._use_io_binding is False and provider == "CUDAExecutionProvider":
             self.use_io_binding = True
             logger.info(
                 "use_io_binding was set to False, setting it to True because it can provide a huge speedup on GPUs. "
                 "It is possible to disable this feature manually by setting the use_io_binding attribute back to False."
             )

-        self.device = device
-        provider = get_provider_for_device(self.device)
-        validate_provider_availability(provider)  # raise error if the provider is not available
+        if provider == "ROCMExecutionProvider":
+            self.use_io_binding = False

         self.model.set_providers([provider], provider_options=[provider_options])
         self.providers = self.model.get_providers()
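The effect of this change can be sketched as follows (assuming a ROCm build of ONNX Runtime; the checkpoint is the tiny one used in the docs above):

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

ort_model = ORTModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-bert-sst2-distilled", export=True
)

# On a ROCm build, the cuda device type resolves to ROCMExecutionProvider,
# and IOBinding stays off since it is only supported for CPU and CUDA EPs.
ort_model = ort_model.to("cuda")
print(ort_model.providers)       # ['ROCMExecutionProvider', 'CPUExecutionProvider']
print(ort_model.use_io_binding)  # False
```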

optimum/onnxruntime/utils.py

+10 -3

@@ -63,7 +63,9 @@ def _is_gpu_available():
     Checks if a gpu is available.
     """
     available_providers = ort.get_available_providers()
-    if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available():
+    if (
+        "CUDAExecutionProvider" in available_providers or "ROCMExecutionProvider" in available_providers
+    ) and torch.cuda.is_available():
         return True
     else:
         return False
@@ -184,7 +186,7 @@ def get_device_for_provider(provider: str, provider_options: Dict) -> torch.devi
     """
     Gets the PyTorch device (CPU/CUDA) associated with an ONNX Runtime provider.
     """
-    if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
+    if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider", "ROCMExecutionProvider"]:
         return torch.device(f"cuda:{provider_options['device_id']}")
     else:
         return torch.device("cpu")
@@ -194,7 +196,12 @@ def get_provider_for_device(device: torch.device) -> str:
     """
     Gets the ONNX Runtime provider associated with the PyTorch device (CPU/CUDA).
     """
-    return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"
+    if device.type.lower() == "cuda":
+        if "ROCMExecutionProvider" in ort.get_available_providers():
+            return "ROCMExecutionProvider"
+        else:
+            return "CUDAExecutionProvider"
+    return "CPUExecutionProvider"


 def parse_device(device: Union[torch.device, str, int]) -> Tuple[torch.device, Dict]:
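A short sketch of the new mapping behavior (run on a machine whose `onnxruntime` build exposes `ROCMExecutionProvider`):

```python
import torch
from optimum.onnxruntime.utils import get_device_for_provider, get_provider_for_device

# A cuda device now resolves to the ROCm provider when one is available,
# and to CUDAExecutionProvider otherwise.
provider = get_provider_for_device(torch.device("cuda"))
print(provider)

# The reverse mapping places the ROCm provider on a cuda device as well.
print(get_device_for_provider(provider, provider_options={"device_id": 0}))
```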

optimum/utils/testing_utils.py

+11

@@ -69,6 +69,17 @@ def require_torch_gpu(test_case):
     return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)


+def require_ort_rocm(test_case):
+    """Decorator marking a test that requires ROCMExecutionProvider for ONNX Runtime."""
+    import onnxruntime as ort
+
+    providers = ort.get_available_providers()
+
+    return unittest.skipUnless("ROCMExecutionProvider" == providers[0], "test requires ROCMExecutionProvider")(
+        test_case
+    )
+
+
 def require_hf_token(test_case):
     """
     Decorator marking a test that requires huggingface hub token.
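Usage of the new decorator follows the existing `require_*` helpers; for example (an illustrative test, not part of this commit):

```python
import unittest

from optimum.utils.testing_utils import require_ort_rocm


class RocmProviderTest(unittest.TestCase):
    @require_ort_rocm
    def test_rocm_is_default_provider(self):
        import onnxruntime as ort

        # Skipped automatically unless ROCMExecutionProvider is the first available provider.
        self.assertEqual(ort.get_available_providers()[0], "ROCMExecutionProvider")
```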

pyproject.toml

+3

@@ -33,6 +33,9 @@ known-first-party = ["optimum"]
 [tool.pytest.ini_options]
 markers = [
     "gpu_test",
+    "cuda_ep_test",
+    "trt_ep_test",
+    "rocm_ep_test",
     "tensorflow_test",
     "timm_test",
     "run_in_series",
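These markers let CI select provider-specific test subsets. A test opts in like this (illustrative; such tests are then selected with `pytest -m rocm_ep_test`):

```python
import pytest


@pytest.mark.rocm_ep_test
def test_inference_on_rocm():
    import onnxruntime as ort

    # Only meaningful on a ROCm build of ONNX Runtime.
    assert "ROCMExecutionProvider" in ort.get_available_providers()
```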

tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu

+1 -1

@@ -23,4 +23,4 @@ COPY . /workspace/optimum
 RUN pip install /workspace/optimum[onnxruntime-gpu,tests]

 ENV TEST_LEVEL=1
-CMD pytest onnxruntime/test_*.py --durations=0 -s -vvvvv -m gpu_test
+CMD pytest onnxruntime/test_*.py --durations=0 -s -vvvvv -m "cuda_ep_test or trt_ep_test"
