Skip to content

Commit a11c6c8

Browse files
authored
fix beam search in seq2seq (#1111)
* fix beam search in seq2seq
* add tests
1 parent 248aabd commit a11c6c8

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

optimum/intel/openvino/modeling_seq2seq.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def get_encoder(self):
422422
return self.encoder
423423

424424
def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
425-
self.decoder._reorder_cache(past, beam_idx)
425+
return self.decoder._reorder_cache(past, beam_idx)
426426

427427
def reshape(self, batch_size: int, sequence_length: int):
428428
"""
@@ -627,6 +627,7 @@ def forward(
627627
if self.stateful and past_key_values is None:
628628
self.request.reset_state()
629629
self._past_length = 0
630+
self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
630631

631632
if past_key_values is not None and not self.stateful:
632633
# Flatten the past_key_values
@@ -661,7 +662,6 @@ def forward(
661662
inputs["beam_idx"] = (
662663
self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
663664
)
664-
665665
# Run inference
666666
self.request.start_async(inputs, share_inputs=True)
667667
self.request.wait()
@@ -1016,7 +1016,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneratio
10161016
auto_model_class = WhisperForConditionalGeneration
10171017

10181018
# force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods
1019-
prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
10201019
generate = WhisperForConditionalGeneration.generate
10211020

10221021
@classmethod
@@ -1083,7 +1082,7 @@ def prepare_inputs_for_generation(
10831082

10841083
past_length = 0
10851084
if past_key_values is not None:
1086-
self.decoder._get_past_length(past_key_values)
1085+
past_length = self.decoder._get_past_length(past_key_values)
10871086

10881087
# Some generation methods already pass only the last input ID
10891088
if decoder_input_ids.shape[1] > past_length:

tests/openvino/test_modeling.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1658,6 +1658,21 @@ def test_compare_to_transformers(self, model_arch):
16581658
transformers_outputs = transformers_model(**tokens, **decoder_inputs)
16591659
# Compare tensor outputs
16601660
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
1661+
gen_config = GenerationConfig(
1662+
max_new_tokens=10,
1663+
min_new_tokens=10,
1664+
num_beams=2,
1665+
do_sample=False,
1666+
eos_token_id=None,
1667+
)
1668+
1669+
set_seed(SEED)
1670+
generated_tokens = transformers_model.generate(**tokens, generation_config=gen_config)
1671+
set_seed(SEED)
1672+
ov_generated_tokens = ov_model.generate(**tokens, generation_config=gen_config)
1673+
1674+
self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
1675+
16611676
del transformers_model
16621677
del ov_model
16631678

@@ -2355,12 +2370,12 @@ def test_compare_to_transformers(self, model_arch):
23552370

23562371
processor = get_preprocessor(model_id)
23572372
data = self._generate_random_audio_data()
2358-
features = processor.feature_extractor(data, return_tensors="pt")
2373+
pt_features = processor.feature_extractor(data, return_tensors="pt")
23592374
decoder_start_token_id = transformers_model.config.decoder_start_token_id
23602375
decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
23612376

23622377
with torch.no_grad():
2363-
transformers_outputs = transformers_model(**features, **decoder_inputs)
2378+
transformers_outputs = transformers_model(**pt_features, **decoder_inputs)
23642379

23652380
for input_type in ["pt", "np"]:
23662381
features = processor.feature_extractor(data, return_tensors=input_type)
@@ -2373,6 +2388,21 @@ def test_compare_to_transformers(self, model_arch):
23732388
# Compare tensor outputs
23742389
self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))
23752390

2391+
gen_config = GenerationConfig(
2392+
max_new_tokens=10,
2393+
min_new_tokens=10,
2394+
num_beams=2,
2395+
do_sample=False,
2396+
eos_token_id=None,
2397+
)
2398+
2399+
set_seed(SEED)
2400+
generated_tokens = transformers_model.generate(**pt_features, generation_config=gen_config)
2401+
set_seed(SEED)
2402+
ov_generated_tokens = ov_model.generate(**pt_features, generation_config=gen_config)
2403+
2404+
self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
2405+
23762406
del transformers_model
23772407
del ov_model
23782408
gc.collect()

0 commit comments

Comments (0)