Update ort CIs (slow, gpu, train) #2024

Merged
merged 73 commits into main from enable-ort-gpu-tests on Jan 29, 2025

Changes from 54 of 73 commits
17bc171
update ort CIs
IlyasMoutawwakil Sep 14, 2024
fbaa980
fix train ci
IlyasMoutawwakil Sep 14, 2024
90aa85d
fix gpu ci
IlyasMoutawwakil Sep 14, 2024
87c9f3e
gpus all
IlyasMoutawwakil Sep 14, 2024
0c1c6bd
devel
IlyasMoutawwakil Sep 14, 2024
430260e
enable trt
IlyasMoutawwakil Sep 15, 2024
00e51c7
fix
IlyasMoutawwakil Sep 15, 2024
3fc5486
fix
IlyasMoutawwakil Sep 15, 2024
8044232
fix
IlyasMoutawwakil Sep 16, 2024
2fd4d47
test
IlyasMoutawwakil Sep 27, 2024
1f322fc
rename
IlyasMoutawwakil Sep 27, 2024
6f7c599
change instance
IlyasMoutawwakil Sep 27, 2024
806faca
test
IlyasMoutawwakil Sep 27, 2024
3eecee6
use available
IlyasMoutawwakil Sep 28, 2024
ab62319
Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil Dec 10, 2024
1b7e652
Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil Jan 10, 2025
cebe6bf
update
IlyasMoutawwakil Jan 10, 2025
d0f62b0
shorter labels as well
IlyasMoutawwakil Jan 10, 2025
d001b9b
add onnxruntime-traning
IlyasMoutawwakil Jan 10, 2025
d271637
Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil Jan 13, 2025
a318c0a
fix onnxruntime package checking
IlyasMoutawwakil Jan 13, 2025
7597692
Merge branch 'enable-ort-gpu-tests' of https://github.com/huggingface…
IlyasMoutawwakil Jan 13, 2025
a6b3a8e
fix typo
IlyasMoutawwakil Jan 13, 2025
a5c76c4
fix typo
IlyasMoutawwakil Jan 13, 2025
745ad8d
remove torch version
IlyasMoutawwakil Jan 13, 2025
bb48c4d
fix trainer
IlyasMoutawwakil Jan 13, 2025
0518dfd
fixed trt ep by using trt docker image (the only way to make sure eve…
IlyasMoutawwakil Jan 13, 2025
9635ec4
latest trt version
IlyasMoutawwakil Jan 13, 2025
cb9cb7f
remove pkv speedup timing since never used
IlyasMoutawwakil Jan 13, 2025
eb25460
trust remote code for training datasets
IlyasMoutawwakil Jan 13, 2025
0a7a23d
remove rocm from diffusers tests
IlyasMoutawwakil Jan 13, 2025
64e9c86
move ort training tests to onnxruntime-training
IlyasMoutawwakil Jan 13, 2025
bbed6bc
fix ort training
IlyasMoutawwakil Jan 14, 2025
1334200
fix
IlyasMoutawwakil Jan 14, 2025
84bf7ee
style
IlyasMoutawwakil Jan 14, 2025
be10d26
always assert closenes and not equality
IlyasMoutawwakil Jan 14, 2025
7ba72a6
fixed perceiver
IlyasMoutawwakil Jan 14, 2025
eceba5b
fixed missing position ids when attn mask is given
IlyasMoutawwakil Jan 14, 2025
9150e05
remove num_labels from output shapes as it's not a dynamic axis
IlyasMoutawwakil Jan 14, 2025
198ce06
raise error on missing mandatory inputs
IlyasMoutawwakil Jan 14, 2025
930103f
added atol and rtol as part of the ORTModelTestMixin class
IlyasMoutawwakil Jan 14, 2025
49cfdc0
fix segformer image segmentation
IlyasMoutawwakil Jan 14, 2025
5b8efd4
style
IlyasMoutawwakil Jan 14, 2025
941484a
fix vision encoder io binding
IlyasMoutawwakil Jan 14, 2025
18e887d
hot fix io binding, remove its dependency to the order of inputs and …
IlyasMoutawwakil Jan 15, 2025
88a7e8b
fix
IlyasMoutawwakil Jan 15, 2025
e9abe6a
typo
IlyasMoutawwakil Jan 15, 2025
c9b45ee
unify io binding api with non io binding
IlyasMoutawwakil Jan 15, 2025
aad9aaf
force evaluated shape to int
IlyasMoutawwakil Jan 15, 2025
a29706e
mark pix2struct io binding tests
IlyasMoutawwakil Jan 15, 2025
821c997
force contiguity in forward pass
IlyasMoutawwakil Jan 16, 2025
cc2e124
fixed cryptic contiguity problems
IlyasMoutawwakil Jan 16, 2025
3a2bcee
fix some
IlyasMoutawwakil Jan 16, 2025
f0ea288
fix vision2seq modeling and testing
IlyasMoutawwakil Jan 16, 2025
7e122c0
Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil Jan 28, 2025
dc2361d
Update setup.py
IlyasMoutawwakil Jan 28, 2025
4eb95f1
update import utils
IlyasMoutawwakil Jan 28, 2025
7f1fc40
Update optimum/onnxruntime/modeling_ort.py
IlyasMoutawwakil Jan 28, 2025
696cc95
fix vision encoder decoder io binding
IlyasMoutawwakil Jan 28, 2025
1827450
enable bigbird and bigbirg pegasus and seperate timm slow tests to un…
IlyasMoutawwakil Jan 28, 2025
41abf7f
use bigger machine for slow tests
IlyasMoutawwakil Jan 28, 2025
6f3084a
lower atol and rtol for image classification logits
IlyasMoutawwakil Jan 28, 2025
010030e
fix
IlyasMoutawwakil Jan 28, 2025
445b291
large
IlyasMoutawwakil Jan 28, 2025
04c8904
enable more Longformer and MCTCT
IlyasMoutawwakil Jan 29, 2025
18e1844
enable commented models in export as well
IlyasMoutawwakil Jan 29, 2025
4487c74
uncomment timm slow models, big bird optimization and marian pkv comp…
IlyasMoutawwakil Jan 29, 2025
24d682e
Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil Jan 29, 2025
def5fdb
Merge branch 'main' into enable-ort-gpu-tests
IlyasMoutawwakil Jan 29, 2025
458355d
fix whisper/speech_to_text test and make convolution deterministic
IlyasMoutawwakil Jan 29, 2025
881015c
pin torch for ort training
IlyasMoutawwakil Jan 29, 2025
7c8c56f
ctc and speech also uses convolution so has to be deterministic
IlyasMoutawwakil Jan 29, 2025
3a4bac9
revert vison2seq atol
IlyasMoutawwakil Jan 29, 2025
22 changes: 14 additions & 8 deletions .github/workflows/test_export_onnx_cli.yml
@@ -2,9 +2,11 @@ name: Exporters ONNX CLI / Python - Test

on:
push:
branches: [main]
branches:
- main
pull_request:
branches: [main]
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -19,16 +21,20 @@ jobs:
os: [ubuntu-20.04]

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v2
- name: Checkout repository
uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies for pytorch export

- name: Install dependencies
run: |
pip install .[tests,exporters,diffusers]
- name: Test with unittest
working-directory: tests
- name: Test with pytest
run: |
pytest exporters/onnx/test_exporters_onnx_cli.py -n auto -m "not tensorflow_test and not timm_test" -s --durations=0
pytest tests/exporters/onnx/test_exporters_onnx_cli.py -n auto -m "not tensorflow_test and not timm_test" -s --durations=0
8 changes: 4 additions & 4 deletions .github/workflows/test_onnxruntime.yml
@@ -1,12 +1,12 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: ONNX Runtime / Python - Test

on:
push:
branches: [main]
branches:
- main
pull_request:
branches: [main]
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
58 changes: 41 additions & 17 deletions .github/workflows/test_onnxruntime_gpu.yml
@@ -1,30 +1,54 @@
name: ONNX Runtime / Test GPU
name: ONNX Runtime GPU / Python - Test

on:
workflow_dispatch:
schedule:
- cron: 0 1 */3 * * # at 1am every 3 days
- cron: 0 7 * * * # every day at 7am UTC
pull_request:
types: [opened, synchronize, reopened, labeled]
# uncomment to enable on PR merge on main branch:
#push:
# branches:
# - main
branches:
- main
types:
- opened
- labeled
- reopened
- unlabeled
- synchronize

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
do-the-job:
if: ${{ (github.event_name == 'workflow_dispatch') || (github.event_name == 'schedule') || contains( github.event.pull_request.labels.*.name, 'gpu-test') }}
name: Start self-hosted EC2 runner
build:
if: ${{
(github.event_name == 'push') ||
(github.event_name == 'workflow_dispatch') ||
contains(github.event.pull_request.labels.*.name, 'gpu') ||
contains(github.event.pull_request.labels.*.name, 'onnxruntime-gpu')
}}

runs-on:
group: aws-g6-4xlarge-plus
env:
AWS_REGION: us-east-1

container:
image: nvcr.io/nvidia/tensorrt:24.12-py3
options: --gpus all

steps:
- name: Checkout
uses: actions/checkout@v2
- name: Build image
uses: actions/checkout@v4

- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.9"

- name: Install dependencies
run: |
docker build -f tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu -t onnxruntime-gpu .
- name: Test with unittest within docker container
pip install --upgrade pip
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install .[tests,onnxruntime-gpu,diffusers]
- name: Test with pytest
run: |
docker run --rm --gpus all -v /mnt/cache/.cache/huggingface:/root/.cache/huggingface --workdir=/workspace/optimum/tests onnxruntime-gpu:latest
pytest tests/onnxruntime -m "cuda_ep_test or trt_ep_test" --durations=0 -vvvv -s -n auto
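
The GPU job above now runs directly in a TensorRT container and selects tests through pytest markers rather than building a dedicated Docker image. A minimal, hypothetical sketch of a test that opts into this run via one of the markers named in the -m expression (the test body is illustrative, not taken from the suite):

import pytest
import onnxruntime as ort

@pytest.mark.cuda_ep_test  # selected by `-m "cuda_ep_test or trt_ep_test"` above
def test_cuda_execution_provider_is_available():
    # Only meaningful on the GPU runner where onnxruntime-gpu is installed.
    assert "CUDAExecutionProvider" in ort.get_available_providers()
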
56 changes: 36 additions & 20 deletions .github/workflows/test_onnxruntime_slow.yml
@@ -1,33 +1,49 @@
name: ONNX Runtime slow / Python - Test
name: ONNX Runtime Slow / Python - Test

on:
workflow_dispatch:
schedule:
- cron: 0 7 * * * # every day at 7am
- cron: 0 7 * * * # every day at 7am UTC
pull_request:
branches:
- main
types:
- opened
- labeled
- reopened
- unlabeled
- synchronize

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
build:
strategy:
fail-fast: false
matrix:
python-version: ["3.9"]
os: [ubuntu-20.04]
if: ${{
(github.event_name == 'push') ||
(github.event_name == 'workflow_dispatch') ||
contains(github.event.pull_request.labels.*.name, 'slow') ||
contains(github.event.pull_request.labels.*.name, 'onnxruntime-slow')
}}

runs-on: ubuntu-20.04

runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies for export
run: |
pip install .[tests,onnxruntime,diffusers]
- name: Test with unittest
working-directory: tests
run: |
RUN_SLOW=1 pytest onnxruntime -s -m "run_slow" --durations=0
- name: Checkout
uses: actions/checkout@v4

- name: Setup Python 3.9
uses: actions/setup-python@v5
with:
python-version: "3.9"

- name: Install dependencies
run: |
pip install --upgrade pip
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[tests,onnxruntime,diffusers]
- name: Test with pytest
run: |
RUN_SLOW=1 pytest tests/onnxruntime -m "run_slow" --durations=0 -s -vvvv -n auto
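
Slow tests are now double-gated: the job exports RUN_SLOW=1 and pytest only collects tests carrying the run_slow marker. A hedged sketch of what such a gate can look like in a test module (the real suite may rely on a helper decorator such as transformers' slow instead):

import os
import pytest

@pytest.mark.run_slow
@pytest.mark.skipif(os.environ.get("RUN_SLOW", "0") != "1", reason="set RUN_SLOW=1 to run slow tests")
def test_slow_export_smoke():
    # Placeholder body; a real slow test would export and check a large model here.
    assert True
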
26 changes: 0 additions & 26 deletions .github/workflows/test_onnxruntime_train.yml

This file was deleted.

66 changes: 66 additions & 0 deletions .github/workflows/test_onnxruntime_training.yml
@@ -0,0 +1,66 @@
name: ONNX Runtime Training / Python - Test

on:
workflow_dispatch:
schedule:
- cron: 0 7 * * * # every day at 7am UTC
pull_request:
branches:
- main
types:
- opened
- labeled
- reopened
- unlabeled
- synchronize

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
build:
if: ${{
(github.event_name == 'push') ||
(github.event_name == 'workflow_dispatch') ||
contains( github.event.pull_request.labels.*.name, 'training') ||
contains( github.event.pull_request.labels.*.name, 'onnxruntime-training')
}}

runs-on:
group: aws-g6-4xlarge-plus

container:
image: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
options: --gpus all

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.9"

- name: Install dependencies
env:
TORCH_CUDA_ARCH_LIST: "5.0 6.0 7.0 7.5 8.0 8.6 9.0+PTX"
run: |
pip install --upgrade pip
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install --no-cache-dir torch-ort onnxruntime-training && python -m torch_ort.configure
pip install --no-cache-dir evaluate absl-py rouge_score seqeval sacrebleu nltk scikit-learn
pip install .[tests,onnxruntime-training]
- name: Test with pytest (trainer)
run: |
RUN_SLOW=1 pytest tests/onnxruntime-training/test_trainer.py --durations=0 -s -vvvv
env:
HF_DATASETS_TRUST_REMOTE_CODE: 1

- name: Test with pytest (examples)
run: |
RUN_SLOW=1 pytest tests/onnxruntime-training/test_examples.py --durations=0 -s -vvvv
env:
HF_DATASETS_TRUST_REMOTE_CODE: 1
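
This job installs torch-ort and onnxruntime-training, runs python -m torch_ort.configure, and then exercises the ORTTrainer tests. For orientation, a hypothetical minimal usage sketch of the trainer API under test; the model name, dataset slice, and argument values are placeholders, and attn_implementation="eager" mirrors the change made to the example scripts below:

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Tiny GLUE/SST-2 slice, tokenized with padding so the default collator can batch it.
dataset = load_dataset("glue", "sst2", split="train[:128]").map(
    lambda e: tokenizer(e["sentence"], truncation=True, padding="max_length"), batched=True
)

args = ORTTrainingArguments(output_dir="ort_out", per_device_train_batch_size=8, num_train_epochs=1)
trainer = ORTTrainer(model=model, args=args, train_dataset=dataset)
trainer.train()
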
@@ -333,6 +333,7 @@ def compute_metrics(p):
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
attn_implementation="eager",
)
image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
5 changes: 4 additions & 1 deletion examples/onnxruntime/training/language-modeling/run_clm.py
@@ -442,9 +442,12 @@ def main():
trust_remote_code=model_args.trust_remote_code,
torch_dtype=torch_dtype,
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
attn_implementation="eager",
)
else:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=model_args.trust_remote_code, attn_implementation="eager"
)
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

5 changes: 4 additions & 1 deletion examples/onnxruntime/training/language-modeling/run_mlm.py
@@ -430,10 +430,13 @@ def main():
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
attn_implementation="eager",
)
else:
logger.info("Training new model from scratch")
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
model = AutoModelForMaskedLM.from_config(
config, trust_remote_code=model_args.trust_remote_code, attn_implementation="eager"
)

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
1 change: 1 addition & 0 deletions examples/onnxruntime/training/question-answering/run_qa.py
@@ -364,6 +364,7 @@ def main():
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
attn_implementation="eager",
)

# Tokenizer check: this script requires a fast tokenizer.
@@ -458,6 +458,7 @@ def main():
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
attn_implementation="eager",
)

if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
@@ -527,6 +527,7 @@ def main():
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
attn_implementation="eager",
)
model.config.pad_token_id = model.config.eos_token_id

@@ -404,6 +404,7 @@ def main():
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
attn_implementation="eager",
)

# Preprocessing the raw_datasets
@@ -405,6 +405,7 @@ def get_label_list(labels):
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
attn_implementation="eager",
)

if tokenizer.pad_token is None:
@@ -408,6 +408,7 @@ def main():
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
attn_implementation="eager",
)

# Set decoder_start_token_id
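
Each of the training example scripts above now requests the eager attention implementation when loading or initializing the model. A condensed, illustrative sketch of the two loading paths touched by these hunks (model name is a placeholder; the motivation, presumably keeping the attention path compatible with ONNX Runtime training, is an assumption):

from transformers import AutoConfig, AutoModelForCausalLM

# Pretrained path, as in run_clm.py / run_mlm.py above:
model = AutoModelForCausalLM.from_pretrained("distilgpt2", attn_implementation="eager")

# From-scratch path, mirroring the from_config branch above:
config = AutoConfig.from_pretrained("distilgpt2")
scratch_model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
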
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -154,7 +154,7 @@ class OnnxConfig(ExportConfig, ABC):
"feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-classification": OrderedDict({"logits": {0: "batch_size"}}),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 2: "height", 3: "width"}}),
"image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-to-image": OrderedDict(
{"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
22 changes: 21 additions & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -825,8 +825,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super().outputs

if self.task == "feature-extraction":
common_outputs["last_hidden_state"] = {0: "batch_size"}

return common_outputs


@@ -978,7 +980,14 @@ class PoolFormerOnnxConfig(ViTOnnxConfig):


class SegformerOnnxConfig(YolosOnnxConfig):
pass
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
outputs = super().outputs

if self.task == "image-segmentation":
outputs["logits"] = {0: "batch_size"}

return outputs


class MobileNetV1OnnxConfig(ViTOnnxConfig):
@@ -1667,6 +1676,17 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
"pixel_values": dynamic_axis,
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
outputs = super().outputs

if "logits" in outputs:
# default is {0: "batch_size", 1: "sequence_length"} where sequence_length is a dynamic axis,
# but Perceiver always returns the same max sequence length in the second dimension
outputs["logits"] = {0: "batch_size"}

return outputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
self.is_generating_dummy_inputs = True
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
115 changes: 46 additions & 69 deletions optimum/onnxruntime/base.py
@@ -26,7 +26,7 @@
from ..utils.logging import warn_once
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import get_ordered_input_names, logging
from .utils import logging


logger = logging.get_logger(__name__)
@@ -38,6 +38,11 @@ class ORTModelPart:
It has its own `onnxruntime.InferenceSession`, and can perform a forward pass.
"""

# should be in an ORTMixin
_prepare_io_binding = ORTModel._prepare_io_binding
_prepare_output_buffer = ORTModel._prepare_output_buffer
_output_shape_inference = ORTModel._output_shape_inference

_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

@@ -48,10 +53,12 @@ def __init__(self, session: InferenceSession, parent_model: "ORTModel"):

self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}

self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}

self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)
self.input_shapes = {input_key.name: input_key.shape for input_key in session.get_inputs()}
self.output_shapes = {output_key.name: output_key.shape for output_key in session.get_outputs()}

@property
def device(self):
@@ -118,27 +125,26 @@ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor,
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

if self.device.type == "cuda" and self.parent_model.use_io_binding:
model_inputs = [input_ids]
if "attention_mask" in self.input_names:
model_inputs.append(attention_mask)
io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
ordered_input_names=self._ordered_input_names,
)
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if self.parent_model.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
else:
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

last_hidden_state = model_outputs["last_hidden_state"]

@@ -257,9 +263,7 @@ def forward(
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
use_cache_branch: None = None,
) -> Seq2SeqLMOutput:
# Adding use_cache_branch in the signature here is just a hack for IO Binding

@@ -279,60 +283,45 @@ def forward(
input_ids, past_key_values, cache_position, use_torch=use_torch
)

model_inputs = {
"input_ids": input_ids,
"encoder_hidden_states": encoder_hidden_states,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"cache_position": cache_position,
}
if past_key_values is not None:
model_inputs.update(zip(self.key_value_input_names, past_key_values))

if self.parent_model.use_io_binding:
known_output_shapes = self.compute_past_key_values_output_shapes(
input_ids,
encoder_hidden_states,
use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
past_key_values=past_key_values,
)

outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)

# TODO: fix transformers generate to have contiguous input_ids here already
# For an unknown reason, calling `contiguous()` here is necessary to not have errors
# on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
model_inputs = [input_ids.contiguous()]

if "encoder_hidden_states" in self.input_names:
model_inputs.append(encoder_hidden_states)

if "decoder_attention_mask" in self.input_names:
model_inputs.append(decoder_attention_mask)

if "encoder_attention_mask" in self.input_names:
model_inputs.append(encoder_attention_mask)

if past_key_values is not None:
model_inputs += past_key_values

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

if use_cache_branch_tensor is not None:
model_inputs.append(use_cache_branch_tensor)

if "cache_position" in self.input_names:
model_inputs.append(cache_position)

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
io_binding, output_shapes, output_buffers = self._prepare_io_binding(
self.session,
*model_inputs,
model_inputs,
known_output_shapes=known_output_shapes,
ordered_input_names=self._ordered_input_names,
outputs_to_not_bind=outputs_to_not_bind,
)

if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
else:
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

# Set -1 for sequence_length as it could be larger than the real sequence_length
for name, shape in output_shapes.items():
if name in self.key_value_output_names:
output_shapes[name] = shape[:2] + (-1,) + shape[3:]

io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

# Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = ()
@@ -382,21 +371,9 @@ def forward(
else:
raise ValueError("Unsupported num_pkv")
else:
model_inputs = {
"input_ids": input_ids,
"encoder_hidden_states": encoder_hidden_states,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"cache_position": cache_position,
"labels": labels,
}
if past_key_values is not None:
model_inputs.update(zip(self.key_value_input_names, past_key_values))

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

# TODO: using a new variable out_past_key_values is memory inefficient,
# past_key_values is not used anymore at this point
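
The refactor above makes every ORTModelPart build a single model_inputs dict and then route it either through IO Binding or through plain session.run, via the shared _prepare_onnx_inputs/_prepare_onnx_outputs helpers. A simplified, hypothetical sketch of the non-IO-Binding path using onnxruntime directly (assumes a local model.onnx whose inputs match the dict below):

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Build one dict of candidate inputs, then feed only those the session actually declares.
model_inputs = {
    "input_ids": np.ones((1, 8), dtype=np.int64),
    "attention_mask": np.ones((1, 8), dtype=np.int64),
}
session_input_names = {i.name for i in session.get_inputs()}
onnx_inputs = {k: v for k, v in model_inputs.items() if k in session_input_names and v is not None}

onnx_outputs = session.run(None, onnx_inputs)
model_outputs = dict(zip([o.name for o in session.get_outputs()], onnx_outputs))
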
72 changes: 25 additions & 47 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
@@ -209,7 +209,6 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: bool = None,
**kwargs,
) -> CausalLMOutputWithPast:
@@ -218,8 +217,7 @@ def forward(
self.raise_on_numpy_input_io_binding(use_torch)

known_output_shapes = {}
use_cache_branch = None
loss = None

if self.use_cache:
if past_key_values is not None:
# Flatten the past_key_values (gpt_bigcode has fused key/value cache, so no need to flatten it)
@@ -233,35 +231,28 @@ def forward(
input_ids, past_key_values, use_torch
)

if self.use_io_binding:
# TODO: fix transformers generate to have contiguous input_ids here already
# For an unknown reason, calling `contiguous()` here is necessary to not have errors
# on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
# I suspect the reason is the contiguous python list that messes something up?
model_inputs = [input_ids.contiguous()]

if "attention_mask" in self.input_names:
model_inputs.append(attention_mask)

if "position_ids" in self.input_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
model_inputs.append(position_ids.contiguous())

if past_key_values is not None:
model_inputs += past_key_values
# Create position_ids on the fly for batch generation
if "position_ids" in self.input_names and position_ids is None and attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

if use_cache_branch is not None:
model_inputs.append(use_cache_branch)
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
}

if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})
if past_key_values is not None:
model_inputs.update(
zip(self.key_value_input_names, past_key_values),
)

io_binding, output_shapes, output_buffers = self.prepare_io_binding(
*model_inputs,
known_output_shapes=known_output_shapes,
ordered_input_names=self._ordered_input_names,
if self.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(
self.model, model_inputs, known_output_shapes=known_output_shapes
)

if self.device.type == "cpu":
@@ -271,32 +262,19 @@ def forward(
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

loss = output_buffers.get("loss", None)
logits = output_buffers["logits"].view(output_shapes["logits"])

if self.use_cache:
# Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 for the self-attention)
past_key_values = tuple(
output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names
)

logits = output_buffers["logits"].view(output_shapes["logits"])

if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
"labels": labels,
}
if past_key_values is not None:
model_inputs.update(
zip(self.key_value_input_names, past_key_values),
)

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]
533 changes: 277 additions & 256 deletions optimum/onnxruntime/modeling_ort.py

Large diffs are not rendered by default.

185 changes: 56 additions & 129 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -23,7 +23,6 @@
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -363,43 +362,28 @@ def forward(
use_torch = isinstance(input_features, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
model_inputs = (
[input_features, attention_mask] if "attention_mask" in self.input_names else [input_features]
)
io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
ordered_input_names=self._ordered_input_names,
)
model_inputs = {
"input_features": input_features,
"attention_mask": attention_mask,
}

io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
if self.parent_model.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
if use_torch:
onnx_inputs = {"input_features": input_features.cpu().detach().numpy()}
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
else:
onnx_inputs = {"input_features": input_features}
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

# TODO: Replace with a better solution
# attention_mask is exported with int64 datatype and tokenizer produces int32 input
# for speech2text model. Hence, the input is type casted for inference.
if "attention_mask" in self.input_names:
if self.session.get_inputs()[1].type == "tensor(int64)":
onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64)

outputs = self.session.run(None, onnx_inputs)
last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)

@@ -422,60 +406,30 @@ def forward(
use_torch = isinstance(pixel_values, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
known_output_shapes = self.compute_encoder_known_output_shapes(pixel_values)
model_inputs = {
"pixel_values": pixel_values,
}

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
pixel_values,
known_output_shapes=known_output_shapes,
ordered_input_names=self._ordered_input_names,
)
if self.parent_model.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
else:
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
if use_torch:
onnx_inputs = {"pixel_values": pixel_values.cpu().detach().numpy()}
else:
onnx_inputs = {"pixel_values": pixel_values}

outputs = self.session.run(None, onnx_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)

def compute_encoder_known_output_shapes(self, pixel_values: torch.FloatTensor) -> Dict[str, List[int]]:
if self.normalized_config.config.model_type == "donut-swin":
# TODO: kind of weird to export to ONNX with dynamic output shape if it is in fact static...
encoder_sequence_length = (
self.normalized_config.config.image_size[0]
* self.normalized_config.config.image_size[1]
// self.normalized_config.config.hidden_size
)
elif self.normalized_config.config.model_type in ["vit", "deit"]:
return None
else:
raise ValueError(
f"Unsupported encoder model type {self.normalized_config.config.model_type} for ORTForVisionSeq2Seq with IOBinding."
"Currently supported models are vit, donut-swin and deit."
"Please submit a PR to add support for this model type."
)

return {
"last_hidden_state": [
pixel_values.shape[0], # batch size
encoder_sequence_length,
self.normalized_config.config.hidden_size,
]
}


class ORTEncoderForPix2Struct(ORTEncoder):
"""
@@ -496,41 +450,28 @@ def forward(
use_torch = isinstance(flattened_patches, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
model_inputs = (
[flattened_patches, attention_mask] if "attention_mask" in self.input_names else [flattened_patches]
)
io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
ordered_input_names=self._ordered_input_names,
)
model_inputs = {
"flattened_patches": flattened_patches,
"attention_mask": attention_mask,
}

io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
if self.parent_model.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
if use_torch:
onnx_inputs = {"flattened_patches": flattened_patches.cpu().detach().numpy()}
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
if self.device.type == "cpu":
self.session.run_with_iobinding(io_binding)
else:
onnx_inputs = {"flattened_patches": flattened_patches}
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask

if "attention_mask" in self.input_names:
if self.session.get_inputs()[1].type == "tensor(int64)":
onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64)
io_binding.synchronize_inputs()
self.session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

outputs = self.session.run(None, onnx_inputs)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)

@@ -1164,7 +1105,6 @@ def forward(
decoder_input_ids: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> Seq2SeqLMOutput:
# Encode if needed : first prediction pass
@@ -1181,7 +1121,6 @@ def forward(
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
@@ -1297,7 +1236,6 @@ def forward(
decoder_input_ids: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Seq2SeqLMOutput:
@@ -1316,7 +1254,6 @@ def forward(
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
cache_position=cache_position,
labels=labels,
)

return Seq2SeqLMOutput(
@@ -1477,10 +1414,8 @@ def forward(
decoder_input_ids: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> Seq2SeqLMOutput:
# Encode if needed : first prediction pass
if encoder_outputs is None:
encoder_outputs = self.encoder(pixel_values=pixel_values)

@@ -1489,17 +1424,18 @@ def forward(
if past_key_values is None or not self.use_cache or self.use_merged
else self.decoder_with_past
)

decoder_outputs = model(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
)

def prepare_inputs_for_generation(
@@ -1577,42 +1513,33 @@ def forward(
decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> Seq2SeqLMOutput:
# Encode if needed : first prediction pass
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
flattened_patches=flattened_patches,
attention_mask=attention_mask,
)

# TODO: for some reason the attention_mask for pix2struct is a float in transformers and not an int64. This conflicts with the exporter,
# which hardcodes an int64 input dtype for the attention mask. This workaround is quite ugly; it should rather be fixed in the ONNX exporter.
if isinstance(attention_mask, torch.Tensor):
attention_mask = attention_mask.to(torch.int64)
else:
attention_mask = attention_mask.astype(np.int64)

model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
if self.use_merged or not self.use_cache or past_key_values is None
else self.decoder_with_past
)

decoder_outputs = model(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
labels=labels,
)

return Seq2SeqLMOutput(
loss=decoder_outputs.get("loss", None),
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
)

def prepare_inputs_for_generation(
78 changes: 66 additions & 12 deletions optimum/onnxruntime/trainer.py
@@ -14,6 +14,7 @@
"""
The ORTTrainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task with ONNX Runtime.
"""

import functools
import math
import os
@@ -27,8 +28,8 @@

# Integrations must be imported before ML frameworks:
# isort: off
import safetensors
from transformers.integrations import hp_params

from transformers.utils import is_accelerate_available
from packaging import version

@@ -58,7 +59,7 @@
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerState
from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerState
from transformers.trainer_pt_utils import (
get_model_param_count,
get_module_class_from_name,
@@ -77,6 +78,8 @@
)
from transformers.training_args import ParallelMode
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_apex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
@@ -119,23 +122,24 @@

# Name of the files used for checkpointing
TRAINER_STATE_NAME = "trainer_state.json"
TRAINING_ARGS_NAME = "training_args.bin"

logger = logging.get_logger(__name__)


class ModuleWithLoss(nn.Module):
class ModuleWithLoss(PreTrainedModel):
def __init__(self, model, args, label_smoother):
super().__init__()
self._original_model = model
self.args = args
# Label smoothing
self.label_smoother = label_smoother

def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs, num_items_in_batch):
# The compute_model_plus_loss_internal is assigned once the class is instantiated.
# It should have the same signature as Trainer.compute_loss().
# We do this to avoid potentially un-synced states if we duplicated the compute-loss code.
return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)
return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs, num_items_in_batch)

@property
def module(self):
@@ -291,14 +295,14 @@ def _set_signature_columns_if_needed(self):
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def compute_loss(self, model_with_loss, inputs, return_outputs=False):
def compute_loss(self, model_with_loss, inputs, return_outputs=False, num_items_in_batch=None):
# Run model forward + loss compute.
if isinstance(self.model, ModuleWithLoss):
# ORTModule Does not support the BatchEncoding Type so we have to convert to a dict.
dict_inputs = dict(inputs.items())
return model_with_loss(dict_inputs, return_outputs)
return model_with_loss(dict_inputs, return_outputs, num_items_in_batch)
else:
return super().compute_loss(model_with_loss, inputs, return_outputs)
return super().compute_loss(model_with_loss, inputs, return_outputs, num_items_in_batch)

def train(
self,
@@ -508,8 +512,13 @@ def _inner_training_loop(
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

self.state = TrainerState()
self.state = TrainerState(
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]
)
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size

Reviewer comment (Collaborator): also cc @JingyaHuang who took care of the ort training integrations

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
@@ -798,12 +807,16 @@ def get_dataloader_sampler(dataloader):
self.lr_scheduler.step()

model.zero_grad()
grad_norm: Optional[float] = None
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)

self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
if is_transformers_version(">=", "4.47.0"):
self._maybe_log_save_evaluate(
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
)
else:
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

@@ -818,8 +831,13 @@ def get_dataloader_sampler(dataloader):
self.control.should_training_stop = True

self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

if is_transformers_version(">=", "4.47.0"):
self._maybe_log_save_evaluate(
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
)
else:
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
logger.warning(
"You enabled PyTorch/XLA debug metrics which is not supported by ONNX "
@@ -1072,3 +1090,39 @@ def get_ort_optimizer_cls_and_kwargs(args: ORTTrainingArguments) -> Tuple[Any, A
else:
raise ValueError(f"ORTTrainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs

def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")

supported_classes = (PreTrainedModel,)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()

if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
self.accelerator.unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_model(
self.model, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)

if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
1 change: 0 additions & 1 deletion optimum/utils/__init__.py
@@ -51,7 +51,6 @@
is_transformers_available,
is_transformers_version,
require_numpy_strictly_lower,
torch_version,
)
from .input_generators import (
DEFAULT_DUMMY_SHAPES,
55 changes: 32 additions & 23 deletions optimum/utils/import_utils.py
@@ -15,11 +15,10 @@

import importlib.metadata
import importlib.util
import inspect
import operator as op
from collections import OrderedDict
from contextlib import contextmanager
from typing import Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
from packaging import version
@@ -37,16 +36,37 @@
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
def _is_package_available(
pkg_name: str,
return_version: bool = False,
alt_pkg_names: Optional[List[str]] = None,
) -> Union[Tuple[bool, str], bool]:
"""
Check if a package is available in the current environment and not just an importable module by checking its version.
Optionally return the version of the package.
Args:
pkg_name (str): The name of the package to check.
return_version (bool): Whether to return the version of the package.
alt_pkg_names (Optional[List[str]]): A list of alternative package names to check if the main package
name is not found.
Returns:
Union[Tuple[bool, str], bool]: A tuple of the package availability and the version of the package if `return_version` is `True`.
"""

package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
package_exists = False
for pkg in [pkg_name] + (alt_pkg_names or []):
try:
package_version = importlib.metadata.version(pkg)
package_exists = True
break
except importlib.metadata.PackageNotFoundError:
package_exists = False
pass

if return_version:
return package_exists, package_version
else:
@@ -64,12 +84,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True)
_transformers_available, _transformers_version = _is_package_available("transformers", return_version=True)
_torch_available, _torch_version = _is_package_available("torch", return_version=True)

# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
_onnxruntime_available = _is_package_available("onnxruntime", return_version=False)

# TODO : Remove
torch_version = version.parse(importlib.metadata.version("torch")) if _torch_available else None
_onnxruntime_available, _onnxruntime_version = _is_package_available(
"onnxruntime", return_version=True, alt_pkg_names=["onnxruntime-gpu", "onnxruntime-training"]
)


# Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below
@@ -168,14 +185,6 @@ def is_onnx_available():


def is_onnxruntime_available():
try:
# Try to import the source file of onnxruntime - if you run the tests from `tests` the function gets
# confused since there a folder named `onnxruntime` in `tests`. Therefore, `_onnxruntime_available`
# will be set to `True` even if not installed.
mod = importlib.import_module("onnxruntime")
inspect.getsourcefile(mod)
except Exception:
return False
return _onnxruntime_available


4 changes: 4 additions & 0 deletions setup.py
@@ -88,6 +88,10 @@
"executorch>=0.4.0",
"transformers>=4.46",
],
"onnxruntime-training": [
"torch-ort",
"onnxruntime-training",
],
"diffusers": ["diffusers"],
"intel": "optimum-intel>=1.18.0",
"openvino": "optimum-intel[openvino]>=1.18.0",
@@ -25,7 +25,7 @@
class ORTTrainerExampleTest(unittest.TestCase):
def test_text_classification(self):
subprocess.run(
"cp ../examples/onnxruntime/training/text-classification/run_glue.py ./",
"cp examples/onnxruntime/training/text-classification/run_glue.py ./",
shell=True,
)

@@ -51,7 +51,7 @@ def test_text_classification(self):

def test_token_classification(self):
subprocess.run(
"cp ../examples/onnxruntime/training/token-classification/run_ner.py ./",
"cp examples/onnxruntime/training/token-classification/run_ner.py ./",
shell=True,
)

@@ -75,7 +75,7 @@ def test_token_classification(self):

def test_translation(self):
subprocess.run(
"cp ../examples/onnxruntime/training/translation/run_translation.py ./",
"cp examples/onnxruntime/training/translation/run_translation.py ./",
shell=True,
)

@@ -105,7 +105,7 @@ def test_translation(self):
@pytest.mark.skip(reason="skip for now")
def test_summarization(self):
subprocess.run(
"cp ../examples/onnxruntime/training/summarization/run_summarization.py ./",
"cp examples/onnxruntime/training/summarization/run_summarization.py ./",
shell=True,
)

@@ -139,7 +139,7 @@ def test_stable_diffusion_txt2img(self):
@pytest.mark.skip(reason="skip for now")
def test_question_answering(self):
subprocess.run(
"cp ../examples/onnxruntime/training/question-answering/run_qa.py ./",
"cp examples/onnxruntime/training/question-answering/run_qa.py ./",
shell=True,
)

@@ -166,7 +166,7 @@ def test_question_answering(self):
@pytest.mark.skip(reason="skip for now")
def test_language_modeling(self):
subprocess.run(
"cp ../examples/onnxruntime/training/question-answering/run_qa.py ./",
"cp examples/onnxruntime/training/question-answering/run_qa.py ./",
shell=True,
)

@@ -194,7 +194,7 @@ def test_language_modeling(self):
@pytest.mark.skip(reason="skip for now")
def test_image_classification(self):
subprocess.run(
"cp ../examples/onnxruntime/training/image-classification/run_image_classification.py ./",
"cp examples/onnxruntime/training/image-classification/run_image_classification.py ./",
shell=True,
)

@@ -60,11 +60,11 @@
nltk.download("punkt")

_ENCODERS_TO_TEST = {
("distilbert", "distilbert-base-cased"),
("distilbert", "distilbert-base-uncased"),
}

_DECODERS_TO_TEST = {
("gpt2", "gpt2"),
("gpt2", "distilgpt2"),
}

_SEQ2SEQ_MODELS_TO_TEST = {
@@ -78,11 +78,6 @@
"data_collator": default_data_collator,
"data_collator_class": DataCollatorWithPadding,
},
# "token-classification": {
# "dataset": ["conll2003"],
# "metric": ["seqeval"],
# "data_collator_class": DataCollatorForTokenClassification,
# },
}

_DECODER_TASKS_DATASETS_CONFIGS = {
@@ -235,7 +230,7 @@ def load_and_prepare(task):

def load_and_prepare_glue(model_name, data_metric_config, max_seq_length, padding="max_length", **kwargs):
# Prepare model
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prepare dataset
@@ -295,7 +290,9 @@ def load_and_prepare_ner(model_name, data_metric_config, max_seq_length, padding
label_list = dataset["train"].features[f"{task}_tags"].feature.names

# Prepare model
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=len(label_list))
model = AutoModelForTokenClassification.from_pretrained(
model_name, num_labels=len(label_list), attn_implementation="eager"
)
if model_name.split("-")[0] in {"gpt2", "roberta"}:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, add_prefix_space=True)
else:
@@ -387,7 +384,7 @@ def load_and_prepare_clm(model_name, data_metric_config, max_seq_length, padding
metric = load(*data_metric_config["metric"])

# Prepare model
model = AutoModelForCausalLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prepare dataset
@@ -462,7 +459,7 @@ def compute_metrics(eval_pred):

def load_and_prepare_xsum(model_name, data_metric_config, _, **kwargs):
# Prepare model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load dataset and metric
@@ -600,7 +597,7 @@ def test_trainer_fp32(self, test_name, model_name, task, data_metric_config):
trainer.train()
trainer.save_model()
trainer.evaluate()
trainer.predict(test_dataset)
# trainer.predict(test_dataset)
gc.collect()

@slow
@@ -639,7 +636,7 @@ def test_trainer_fp32_with_label_smoothing(self, test_name, model_name, task, da
trainer.train()
trainer.save_model()
trainer.evaluate()
trainer.predict(test_dataset)
# trainer.predict(test_dataset)
gc.collect()

@slow
@@ -678,7 +675,7 @@ def test_trainer_fp16(self, test_name, model_name, task, data_metric_config):
trainer.train()
trainer.save_model()
trainer.evaluate()
trainer.predict(test_dataset)
# trainer.predict(test_dataset)
gc.collect()


@@ -730,7 +727,7 @@ def test_trainer_fp16_ds_stage1(self, test_name, model_name, task, data_metric_c
weight_decay=self.weight_decay,
logging_dir=tmp_dir,
fp16=True,
deepspeed="onnxruntime/ds_configs/ds_config_zero_stage_1.json",
deepspeed="tests/onnxruntime-training/ds_configs/ds_config_zero_stage_1.json",
)

trainer, _ = get_ort_trainer(
@@ -769,7 +766,7 @@ def test_trainer_fp16_ds_stage2(self, test_name, model_name, task, data_metric_c
weight_decay=self.weight_decay,
logging_dir=tmp_dir,
fp16=True,
deepspeed="onnxruntime/ds_configs/ds_config_zero_stage_2.json",
deepspeed="tests/onnxruntime-training/ds_configs/ds_config_zero_stage_2.json",
)

trainer, _ = get_ort_trainer(
26 changes: 0 additions & 26 deletions tests/onnxruntime/docker/Dockerfile_onnxruntime_gpu

This file was deleted.

83 changes: 0 additions & 83 deletions tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer

This file was deleted.

18 changes: 12 additions & 6 deletions tests/onnxruntime/test_diffusion.py
@@ -281,16 +281,18 @@ def test_negative_prompt(self, model_arch: str):
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"provider": ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"],
"provider": ["CUDAExecutionProvider", "TensorrtExecutionProvider"],
}
)
)
@pytest.mark.rocm_ep_test
@pytest.mark.cuda_ep_test
@pytest.mark.trt_ep_test
@require_torch_gpu
@require_diffusers
def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str):
if provider == "TensorrtExecutionProvider" and model_arch != self.__class__.SUPPORTED_ARCHITECTURES[0]:
self.skipTest("Testing a single arch for TensorrtExecutionProvider")

model_args = {"test_name": test_name, "model_arch": model_arch}
self._setup(model_args)

@@ -519,16 +521,18 @@ def test_image_reproducibility(self, model_arch: str):
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"provider": ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"],
"provider": ["CUDAExecutionProvider", "TensorrtExecutionProvider"],
}
)
)
@pytest.mark.rocm_ep_test
@pytest.mark.cuda_ep_test
@pytest.mark.trt_ep_test
@require_torch_gpu
@require_diffusers
def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str):
if provider == "TensorrtExecutionProvider" and model_arch != self.__class__.SUPPORTED_ARCHITECTURES[0]:
self.skipTest("Testing a single arch for TensorrtExecutionProvider")

model_args = {"test_name": test_name, "model_arch": model_arch}
self._setup(model_args)

@@ -759,16 +763,18 @@ def test_image_reproducibility(self, model_arch: str):
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"provider": ["CUDAExecutionProvider", "ROCMExecutionProvider", "TensorrtExecutionProvider"],
"provider": ["CUDAExecutionProvider", "TensorrtExecutionProvider"],
}
)
)
@pytest.mark.rocm_ep_test
@pytest.mark.cuda_ep_test
@pytest.mark.trt_ep_test
@require_torch_gpu
@require_diffusers
def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str):
if provider == "TensorrtExecutionProvider" and model_arch != self.__class__.SUPPORTED_ARCHITECTURES[0]:
self.skipTest("Testing a single arch for TensorrtExecutionProvider")

model_args = {"test_name": test_name, "model_arch": model_arch}
self._setup(model_args)

907 changes: 450 additions & 457 deletions tests/onnxruntime/test_modeling.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -182,6 +182,9 @@ class ORTModelTestMixin(unittest.TestCase):
"np": np.ndarray,
}

ATOL = 1e-4
RTOL = 1e-4

TASK = None

ORTMODEL_CLASS = None
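
With ATOL and RTOL now defined on ORTModelTestMixin, the suite can compare ONNX Runtime outputs against PyTorch references by closeness rather than strict equality (see the commit above about always asserting closeness instead of equality). A hypothetical sketch of such a check; the tensors are placeholders and the real tests may use different helpers:

import torch

ATOL = RTOL = 1e-4  # values added to ORTModelTestMixin above

ort_logits = torch.tensor([1.00001, 2.0, 3.0])  # stand-in for ONNX Runtime output
pt_logits = torch.tensor([1.0, 2.0, 3.0])       # stand-in for the PyTorch reference
torch.testing.assert_close(ort_logits, pt_logits, atol=ATOL, rtol=RTOL)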