Skip to content

Commit e4cc078

Browse files
committed
fix beam search in seq2seq
1 parent 248aabd commit e4cc078

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

optimum/intel/openvino/modeling_seq2seq.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -422,7 +422,7 @@ def get_encoder(self):
         return self.encoder

     def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        self.decoder._reorder_cache(past, beam_idx)
+        return self.decoder._reorder_cache(past, beam_idx)

     def reshape(self, batch_size: int, sequence_length: int):
         """
@@ -627,6 +627,7 @@ def forward(
         if self.stateful and past_key_values is None:
             self.request.reset_state()
             self._past_length = 0
+            self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)

         if past_key_values is not None and not self.stateful:
             # Flatten the past_key_values
@@ -661,7 +662,6 @@ def forward(
         inputs["beam_idx"] = (
             self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
         )
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -1016,7 +1016,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneration):
     auto_model_class = WhisperForConditionalGeneration

     # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods
-    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
     generate = WhisperForConditionalGeneration.generate

     @classmethod
@@ -1083,7 +1082,7 @@ def prepare_inputs_for_generation(

         past_length = 0
         if past_key_values is not None:
-            self.decoder._get_past_length(past_key_values)
+            past_length = self.decoder._get_past_length(past_key_values)

         # Some generation methods already pass only the last input ID
         if decoder_input_ids.shape[1] > past_length:

0 commit comments

Comments (0)