diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 9ab494be6b..e4dc1ed784 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -17,7 +17,7 @@
 import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -28,6 +28,10 @@
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -41,6 +45,11 @@
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
 
 
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.streamers import BaseStreamer
+
+
 logger = logging.getLogger(__name__)
 
 core = Core()
@@ -122,6 +131,8 @@ def __init__(
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
         self._past_length = 0
+        self._first_iter_beam_search = False
+        self._second_iter_beam_search = False
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -375,7 +386,11 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -418,7 +433,6 @@
                 self.next_beam_idx = np.arange(batch_size, dtype=int)
                 self._past_length = 0
         past_len = self._get_past_length(past_key_values)
 
-        inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
         if "attention_mask" in self.input_names or "position_ids" in self.input_names:
@@ -468,6 +482,8 @@
             **kwargs,
         )
 
+        if self._first_iter_beam_search:
+            inputs, duplication_indices = self._deduplicate_inputs(inputs)
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -483,7 +499,11 @@
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if (
+                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                or self.config.model_type == "falcon"
+                and self.config.new_decoder_architecture
+            ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
                     past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
@@ -491,6 +511,10 @@
         else:
             past_key_values = None
 
+        if self._first_iter_beam_search:
+            logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
+            self._first_iter_beam_search = False
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
     # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
@@ -520,7 +544,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": use_cache,
@@ -528,12 +552,116 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }
 
+        return model_inputs
+
+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
+                    )
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
+        return logits, past_key_values
+
+    def _deduplicate_inputs(self, model_inputs: Dict):
+        input_ids = model_inputs["input_ids"]
+        upd_model_inputs = {}
+        unique_input_ids, indicies, reverse_indicies = np.unique(
+            input_ids, axis=0, return_index=True, return_inverse=True
+        )
+        for input_name, input_tensor in model_inputs.items():
+            if input_name not in ["input_ids", "beam_idx"]:
+                if not isinstance(input_tensor, Tensor):
+                    upd_model_inputs[input_name] = input_tensor[indicies]
+                else:
+                    shape = input_tensor.shape
+                    dtype = input_tensor.element_type
+                    upd_batch_size = indicies.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
+                    upd_model_inputs[input_name] = Tensor(dtype, shape)
+        upd_model_inputs["input_ids"] = unique_input_ids
+        if "beam_idx" in model_inputs:
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
+            upd_model_inputs["beam_idx"] = beam_idx
+        return upd_model_inputs, reverse_indicies
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
+        generation_mode = _generation_config.get_generation_mode(assistant_model)
+
+        is_beam_search = generation_mode in [
+            GenerationMode.BEAM_SEARCH,
+            GenerationMode.BEAM_SAMPLE,
+            GenerationMode.GROUP_BEAM_SEARCH,
+            GenerationMode.CONSTRAINED_BEAM_SEARCH,
+        ]
+        if is_beam_search:
+            self._first_iter_beam_search = True
+        result = super().generate(
+            inputs,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            assistant_model,
+            streamer,
+            negative_prompt_ids,
+            negative_prompt_attention_mask,
+            **kwargs,
+        )
+        return result
+
     def _get_past_length(self, past_key_values=None):
         if past_key_values is None:
             return 0
         if self.stateful:
             return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
         if self.config.model_type == "chatglm":
@@ -558,12 +686,20 @@ def _reorder_cache(
         if self.stateful:
             # TODO: Apply it differently based on model type
             # TODO: At least for bloom we need to replicate values for each attention head
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = (
+                np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            )  # save beam_idx to be used as an input in the next iteration
+            self._second_iter_beam_search = False
             return past_key_values
         else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)
 
     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -684,11 +820,12 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
""" if self.stateful: - beam_idx = np.array(beam_idx) batch_size = beam_idx.shape[0] + beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx indices = np.array(range(batch_size * self.config.num_attention_heads)) indices = indices.reshape([batch_size, self.config.num_attention_heads]) self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() + self._second_iter_beam_search = False return past_key_values else: standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) @@ -738,6 +875,24 @@ def _convert_to_standard_cache( for layer_past in past_key_value ) + def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): + batch_size = logits.shape[0] + if indicies.shape[0] != 1: + logits = logits[indicies] + if past_key_values and not self.stateful: + pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size) + pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard) + past_key_values = self._convert_to_bloom_cache(pkv) + + if self.stateful: + self.next_beam_idx = ( + self.next_beam_idx[indicies] + if self.next_beam_idx is not None + else np.arange(batch_size, dtype=int)[indicies] + ) + self._second_iter_beam_search = True + return logits, past_key_values + class OVGPTBigCodeForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache @@ -745,7 +900,9 @@ def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: if self.stateful: - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx + self._second_iter_beam_search = False return past_key_values else: return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index d4f55c683b..75c95c1563 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -778,6 +778,87 @@ def test_default_filling_attention_mask_and_position_ids(self): del model_with_cache gc.collect() + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @pytest.mark.run_slow + @slow + def test_beam_search(self, model_arch): + model_kwargs = {} + model_id = MODEL_NAMES[model_arch] + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search + if model_arch in ["qwen", "chatglm"]: + return + + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + beam_search_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + ) + beam_sample_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=True, + eos_token_id=None, + top_k=1, + ) + + group_beam_search_gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + num_beam_groups=2, + diversity_penalty=0.0000001, + ) + force_word = "cat" + 
+        force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
+        constrained_beam_search_gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            force_words_ids=force_words_ids,
+        )
+
+        gen_configs = [
+            beam_search_gen_config,
+            beam_sample_gen_config,
+            group_beam_search_gen_config,
+            constrained_beam_search_gen_config,
+        ]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
+        )
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
+        )
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        for gen_config in gen_configs:
+            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+            ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+            ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (