diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py
index 61911fc6d4..983f1f6850 100644
--- a/optimum/intel/openvino/modeling_seq2seq.py
+++ b/optimum/intel/openvino/modeling_seq2seq.py
@@ -422,7 +422,7 @@ def get_encoder(self):
         return self.encoder
 
     def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        self.decoder._reorder_cache(past, beam_idx)
+        return self.decoder._reorder_cache(past, beam_idx)
 
     def reshape(self, batch_size: int, sequence_length: int):
         """
@@ -627,6 +627,7 @@ def forward(
         if self.stateful and past_key_values is None:
             self.request.reset_state()
             self._past_length = 0
+            self.next_beam_idx = np.arange(input_ids.shape[0], dtype=int)
 
         if past_key_values is not None and not self.stateful:
             # Flatten the past_key_values
@@ -661,7 +662,6 @@
             inputs["beam_idx"] = (
                 self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32)
             )
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -1016,7 +1016,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneratio
     auto_model_class = WhisperForConditionalGeneration
 
     # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods
-    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
     generate = WhisperForConditionalGeneration.generate
 
     @classmethod
@@ -1083,7 +1082,7 @@ def prepare_inputs_for_generation(
 
         past_length = 0
         if past_key_values is not None:
-            self.decoder._get_past_length(past_key_values)
+            past_length = self.decoder._get_past_length(past_key_values)
 
             # Some generation methods already pass only the last input ID
             if decoder_input_ids.shape[1] > past_length:
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index efc84ee76a..9da0069706 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -1658,6 +1658,21 @@ def test_compare_to_transformers(self, model_arch):
             transformers_outputs = transformers_model(**tokens, **decoder_inputs)
         # Compare tensor outputs
         self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=2,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        set_seed(SEED)
+        generated_tokens = transformers_model.generate(**tokens, generation_config=gen_config)
+        set_seed(SEED)
+        ov_generated_tokens = ov_model.generate(**tokens, generation_config=gen_config)
+
+        self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
+
         del transformers_model
         del ov_model
 
@@ -2355,12 +2370,12 @@ def test_compare_to_transformers(self, model_arch):
 
         processor = get_preprocessor(model_id)
         data = self._generate_random_audio_data()
-        features = processor.feature_extractor(data, return_tensors="pt")
+        pt_features = processor.feature_extractor(data, return_tensors="pt")
         decoder_start_token_id = transformers_model.config.decoder_start_token_id
         decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
 
         with torch.no_grad():
-            transformers_outputs = transformers_model(**features, **decoder_inputs)
+            transformers_outputs = transformers_model(**pt_features, **decoder_inputs)
 
         for input_type in ["pt", "np"]:
             features = processor.feature_extractor(data, return_tensors=input_type)
@@ -2373,6 +2388,21 @@
             self.assertIn("logits", ov_outputs)
             # Compare tensor outputs
             self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=2,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        set_seed(SEED)
+        generated_tokens = transformers_model.generate(**pt_features, generation_config=gen_config)
+        set_seed(SEED)
+        ov_generated_tokens = ov_model.generate(**pt_features, generation_config=gen_config)
+
+        self.assertTrue(torch.equal(generated_tokens, ov_generated_tokens))
+
         del transformers_model
         del ov_model
         gc.collect()
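Note (not part of the diff): a minimal sketch of the beam-search cache contract that the `return`/assignment fixes above restore. During beam search, transformers replaces `past_key_values` with whatever `_reorder_cache` returns, so computing the reordered cache without returning it silently hands the caller `None`. The names below (`reorder`, `reorder_buggy`, `toy_cache`) are hypothetical stand-ins, not the library code.

import torch

def reorder(past, beam_idx):
    # Mirrors the shape of _reorder_cache: for every layer, select the
    # key/value rows belonging to the beams that survived this step.
    return tuple(tuple(t.index_select(0, beam_idx) for t in layer) for layer in past)

def reorder_buggy(past, beam_idx):
    reorder(past, beam_idx)  # result computed, then dropped: caller receives None

# One decoder layer; key/value tensors shaped [beams, heads, seq_len, head_dim].
toy_cache = ((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)),)
beam_idx = torch.tensor([1, 1, 0, 3])  # which beams survive this step

assert reorder(toy_cache, beam_idx)[0][0].shape == (4, 2, 3, 8)
assert reorder_buggy(toy_cache, beam_idx) is None  # the pre-fix behavior

The new `generate()` checks in the tests exercise exactly this path: with `num_beams=2` and greedy scoring, OpenVINO and PyTorch outputs must match token for token, which fails if the cache (or the stateful model's `next_beam_idx`) is not reordered correctly.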