Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix starcoder ORT integration #1722

Merged
merged 2 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments



""" Fine-tuning a 🤗 Transformers model for image classification"""

logger = logging.getLogger(__name__)
Expand Down
5 changes: 4 additions & 1 deletion optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def forward(

outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)

model_inputs = [input_ids]
# TODO: fix transformers generate to have contiguous input_ids here already
# For an unknown reason, calling `contiguous()` here is necessary to not have errors
# on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.g
model_inputs = [input_ids.contiguous()]

if "encoder_hidden_states" in self.input_names:
model_inputs.append(encoder_hidden_states)
Expand Down
7 changes: 7 additions & 0 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,13 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
)
return model_inputs

# Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)


class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
Expand Down
85 changes: 65 additions & 20 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,8 +2322,8 @@ def test_merge_from_onnx_and_save(self, model_arch):
self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents)
self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)

@parameterized.expand(grid_parameters(FULL_GRID))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
@parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 3]}))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
use_io_binding = None
if use_cache is False:
use_io_binding = False
Expand Down Expand Up @@ -2384,17 +2384,19 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
if model_arch == "falcon":
# TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released, falcon is broken in transformers
new_tokens = 5

onnx_outputs = onnx_model.generate(
**tokens,
num_beams=1,
num_beams=num_beams,
do_sample=False,
min_new_tokens=new_tokens,
max_new_tokens=new_tokens,
eos_token_id=None,
)

transformers_outputs = transformers_model.generate(
**tokens,
num_beams=1,
num_beams=num_beams,
do_sample=False,
min_new_tokens=new_tokens,
max_new_tokens=new_tokens,
Expand Down Expand Up @@ -4123,11 +4125,23 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 3],
}
)
)
@require_torch_gpu
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self,
test_name: str,
model_arch: str,
use_cache: bool,
use_merged: bool,
num_beams: int,
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
Expand Down Expand Up @@ -4159,8 +4173,8 @@ def test_compare_generation_to_io_binding(

tokenizer = get_preprocessor(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
onnx_outputs = onnx_model.generate(**tokens, num_beams=5)
io_outputs = io_model.generate(**tokens, num_beams=5)
onnx_outputs = onnx_model.generate(**tokens, num_beams=num_beams)
io_outputs = io_model.generate(**tokens, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
Expand Down Expand Up @@ -4555,12 +4569,24 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 5],
}
)
)
@require_torch_gpu
@pytest.mark.cuda_ep_test
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self,
test_name: str,
model_arch: str,
use_cache: bool,
use_merged: bool,
num_beams: int,
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
Expand All @@ -4586,8 +4612,8 @@ def test_compare_generation_to_io_binding(
data = self._generate_random_audio_data()
features = processor.feature_extractor(data, return_tensors="pt").to("cuda")

onnx_outputs = onnx_model.generate(**features, num_beams=5)
io_outputs = io_model.generate(**features, num_beams=5)
onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
io_outputs = io_model.generate(**features, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
Expand Down Expand Up @@ -4920,12 +4946,19 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 3],
}
)
)
@require_torch_gpu
@pytest.mark.cuda_ep_test
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool, num_beams: int
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
Expand All @@ -4951,8 +4984,8 @@ def test_compare_generation_to_io_binding(
data = self._get_sample_image()
features = feature_extractor(data, return_tensors="pt").to("cuda")

onnx_outputs = onnx_model.generate(**features, num_beams=5)
io_outputs = io_model.generate(**features, num_beams=5)
onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
io_outputs = io_model.generate(**features, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
Expand Down Expand Up @@ -5336,10 +5369,22 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 3],
}
)
)
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self,
test_name: str,
model_arch: str,
use_cache: bool,
use_merged: bool,
num_beams: int,
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
Expand All @@ -5362,8 +5407,8 @@ def test_compare_generation_to_io_binding(
inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt")
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]
onnx_outputs = onnx_model.generate(**inputs, num_beams=5)
io_outputs = io_model.generate(**inputs, num_beams=5)
onnx_outputs = onnx_model.generate(**inputs, num_beams=num_beams)
io_outputs = io_model.generate(**inputs, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
Expand Down
Loading