
Commit c7cc312

Fix starcoder ORT integration (#1722)
* fix starcoder ort
* fix pix2struct as well
1 parent 80e89f1

File tree: 3 files changed, +76 −21 lines

optimum/onnxruntime/base.py (+4 −1)

@@ -260,7 +260,10 @@ def forward(
 
         outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)
 
-        model_inputs = [input_ids]
+        # TODO: fix transformers generate to have contiguous input_ids here already
+        # For an unknown reason, calling `contiguous()` here is necessary to not have errors
+        # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
+        model_inputs = [input_ids.contiguous()]
 
         if "encoder_hidden_states" in self.input_names:
             model_inputs.append(encoder_hidden_states)
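Note on the fix: at each decoding step, transformers' generate passes only the newest token ids, obtained by slicing the running sequence, and for more than one row such a slice is a non-contiguous view; IO binding hands ONNX Runtime a raw data pointer, so a non-contiguous view can make the runtime read misplaced elements. A minimal sketch of the layout issue, assuming the input is the last-token slice input_ids[:, -1:] (a hypothetical repro, not code from this commit):

    import torch

    # Hypothetical repro of the layout problem the comment describes.
    input_ids = torch.randint(0, 100, (2, 8))        # batch size 2, 8 tokens per row
    last_tokens = input_ids[:, -1:]                  # shape (2, 1) but strides (8, 1): a view

    print(last_tokens.is_contiguous())               # False: rows sit 8 elements apart in memory
    print(last_tokens.contiguous().is_contiguous())  # True: .contiguous() makes a compact copy

    # With batch size 1 the same slice happens to be contiguous, matching the
    # observation that the errors only show up with batch size > 1.
    print(input_ids[:1, -1:].is_contiguous())        # True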

optimum/onnxruntime/modeling_decoder.py (+7 −0)

@@ -722,6 +722,13 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
         )
         return model_inputs
 
+    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+
 
 class ORTBloomForCausalLM(ORTModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
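Note on the override: GPT BigCode uses multi-query attention, so its past is one fused tensor per layer rather than the usual (key, value) pair, and the generic cache reordering would index into the wrong level of nesting; beam search calls _reorder_cache after every step so the cached states track the surviving beams. A hedged sketch of that reordering on made-up shapes (the fused per-layer cache layout is an assumption based on the GPT BigCode implementation):

    import torch

    # Illustrative shapes only: 2 layers, batch * beams = 3 rows in the cache.
    num_layers, batch_beams, seq_len, kv_dim = 2, 3, 4, 16
    past_key_values = tuple(torch.randn(batch_beams, seq_len, kv_dim) for _ in range(num_layers))

    # Suppose beam search keeps beams (2, 0, 0) at this step.
    beam_idx = torch.tensor([2, 0, 0])

    # The same operation the new method performs: permute each layer's cache
    # along the batch dimension so row i now holds the chosen beam's state.
    reordered = tuple(layer_past.index_select(0, beam_idx) for layer_past in past_key_values)

    assert torch.equal(reordered[0][0], past_key_values[0][2])
    assert torch.equal(reordered[0][1], past_key_values[0][0])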

tests/onnxruntime/test_modeling.py (+65 −20)

@@ -2322,8 +2322,8 @@ def test_merge_from_onnx_and_save(self, model_arch):
         self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents)
         self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)
 
-    @parameterized.expand(grid_parameters(FULL_GRID))
-    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
+    @parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 3]}))
+    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
         use_io_binding = None
         if use_cache is False:
             use_io_binding = False
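Note on the test change: grid_parameters expands a dict of option lists into one named case per combination, so adding "num_beams": [1, 3] doubles the grid and exercises both greedy and beam search in every configuration. A rough sketch of that expansion, as a hypothetical minimal reimplementation rather than optimum's actual helper:

    import itertools
    from typing import Any, Dict, Iterable, List

    def grid_parameters_sketch(grid: Dict[str, List[Any]]) -> Iterable[List[Any]]:
        # Cartesian product over the value lists; the joined values double as
        # a readable test name for parameterized.expand.
        for combination in itertools.product(*grid.values()):
            test_name = "_".join(str(value) for value in combination)
            yield [test_name, *combination]

    # {"use_cache": [True], "num_beams": [1, 3]} would yield:
    #   ["True_1", True, 1]
    #   ["True_3", True, 3]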
@@ -2384,17 +2384,19 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
         if model_arch == "falcon":
             # TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released, falcon is broken in transformers
             new_tokens = 5
+
         onnx_outputs = onnx_model.generate(
             **tokens,
-            num_beams=1,
+            num_beams=num_beams,
             do_sample=False,
             min_new_tokens=new_tokens,
             max_new_tokens=new_tokens,
             eos_token_id=None,
         )
+
         transformers_outputs = transformers_model.generate(
             **tokens,
-            num_beams=1,
+            num_beams=num_beams,
             do_sample=False,
             min_new_tokens=new_tokens,
             max_new_tokens=new_tokens,
@@ -4123,11 +4125,23 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 3],
+            }
+        )
     )
     @require_torch_gpu
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self,
+        test_name: str,
+        model_arch: str,
+        use_cache: bool,
+        use_merged: bool,
+        num_beams: int,
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4159,8 +4173,8 @@ def test_compare_generation_to_io_binding(
 
         tokenizer = get_preprocessor(model_id)
         tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
-        onnx_outputs = onnx_model.generate(**tokens, num_beams=5)
-        io_outputs = io_model.generate(**tokens, num_beams=5)
+        onnx_outputs = onnx_model.generate(**tokens, num_beams=num_beams)
+        io_outputs = io_model.generate(**tokens, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
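Note: the assertion demands bitwise-identical generations with and without IO binding, now also under beam search, which is exactly the path the contiguous() call and _reorder_cache fix. A hedged sketch of the same comparison outside the test harness (the checkpoint and model class are placeholder assumptions; like the test, it presumes a CUDA-enabled onnxruntime):

    import torch
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")

    # Same exported model, once without and once with IO binding.
    onnx_model = ORTModelForCausalLM.from_pretrained(
        model_id, export=True, provider="CUDAExecutionProvider", use_io_binding=False
    )
    io_model = ORTModelForCausalLM.from_pretrained(
        model_id, export=True, provider="CUDAExecutionProvider", use_io_binding=True
    )

    onnx_outputs = onnx_model.generate(**tokens, num_beams=3)
    io_outputs = io_model.generate(**tokens, num_beams=3)
    assert torch.equal(onnx_outputs, io_outputs)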
@@ -4555,12 +4569,24 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 5],
+            }
+        )
     )
     @require_torch_gpu
     @pytest.mark.cuda_ep_test
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self,
+        test_name: str,
+        model_arch: str,
+        use_cache: bool,
+        use_merged: bool,
+        num_beams: int,
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4586,8 +4612,8 @@ def test_compare_generation_to_io_binding(
         data = self._generate_random_audio_data()
         features = processor.feature_extractor(data, return_tensors="pt").to("cuda")
 
-        onnx_outputs = onnx_model.generate(**features, num_beams=5)
-        io_outputs = io_model.generate(**features, num_beams=5)
+        onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
+        io_outputs = io_model.generate(**features, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -4920,12 +4946,19 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 3],
+            }
+        )
     )
     @require_torch_gpu
     @pytest.mark.cuda_ep_test
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool, num_beams: int
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4951,8 +4984,8 @@ def test_compare_generation_to_io_binding(
         data = self._get_sample_image()
         features = feature_extractor(data, return_tensors="pt").to("cuda")
 
-        onnx_outputs = onnx_model.generate(**features, num_beams=5)
-        io_outputs = io_model.generate(**features, num_beams=5)
+        onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
+        io_outputs = io_model.generate(**features, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -5336,10 +5369,22 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 3],
+            }
+        )
     )
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self,
+        test_name: str,
+        model_arch: str,
+        use_cache: bool,
+        use_merged: bool,
+        num_beams: int,
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -5362,8 +5407,8 @@ def test_compare_generation_to_io_binding(
         inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt")
         del inputs["decoder_attention_mask"]
         del inputs["decoder_input_ids"]
-        onnx_outputs = onnx_model.generate(**inputs, num_beams=5)
-        io_outputs = io_model.generate(**inputs, num_beams=5)
+        onnx_outputs = onnx_model.generate(**inputs, num_beams=num_beams)
+        io_outputs = io_model.generate(**inputs, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
