@@ -422,7 +422,7 @@ def get_encoder(self):
         return self.encoder
 
     def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        self.decoder._reorder_cache(past, beam_idx)
+        return self.decoder._reorder_cache(past, beam_idx)
 
     def reshape(self, batch_size: int, sequence_length: int):
         """
@@ -627,6 +627,7 @@ def forward(
         if self.stateful and past_key_values is None:
             self.request.reset_state()
             self._past_length = 0
+            self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
 
         if past_key_values is not None and not self.stateful:
             # Flatten the past_key_values
@@ -661,7 +662,6 @@ def forward(
             inputs["beam_idx"] = (
                 self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
            )
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
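
Illustrative sketch (not part of the diff): in the stateful path the KV-cache lives inside the compiled request, so the reset added above makes next_beam_idx fall back to the identity permutation whenever the state is cleared; otherwise a fresh generation could start with a beam ordering left over from the previous run. A rough sketch of that bookkeeping, with the OpenVINO request replaced by a hypothetical stub:

import numpy as np

class StatefulDecoderStub:
    def __init__(self):
        self.next_beam_idx = None  # updated by _reorder_cache during beam search

    def forward(self, input_ids, past_key_values=None):
        batch_size = input_ids.shape[0]
        if past_key_values is None:
            # New generation: the (hypothetical) request state is reset,
            # so the beam index must be reset to the identity as well.
            self.next_beam_idx = np.arange(batch_size, dtype=int)
        return (
            self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
        )

stub = StatefulDecoderStub()
stub.next_beam_idx = np.array([3, 3, 0, 1])            # leftover from a previous run
print(stub.forward(np.zeros((4, 1), dtype=np.int64)))  # [0 1 2 3] after the reset
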
@@ -1016,7 +1016,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneratio
     auto_model_class = WhisperForConditionalGeneration
 
     # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods
-    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
     generate = WhisperForConditionalGeneration.generate
 
     @classmethod
@@ -1083,7 +1082,7 @@ def prepare_inputs_for_generation(
 
         past_length = 0
         if past_key_values is not None:
-            self.decoder._get_past_length(past_key_values)
+            past_length = self.decoder._get_past_length(past_key_values)
 
         # Some generation methods already pass only the last input ID
         if decoder_input_ids.shape[1] > past_length:
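
Illustrative sketch (not part of the diff, hypothetical values): with the assignment restored, past_length reflects how many positions the cache already covers, so prepare_inputs_for_generation trims decoder_input_ids down to the tokens that still need to be decoded instead of re-feeding the whole prefix:

import torch

decoder_input_ids = torch.tensor([[50258, 50259, 50359, 440, 702]])  # 5 tokens so far
past_length = 4  # what self.decoder._get_past_length(past_key_values) would report

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    # otherwise keep only the final token
    remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
print(decoder_input_ids)  # tensor([[702]]) -> only the new token is fed to the decoder
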