From d363a00fbfafbcbb67e0c0033f26637f16d12dfc Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 29 Apr 2024 22:04:33 +0400 Subject: [PATCH 1/8] WIP: beam search only --- optimum/intel/openvino/modeling_decoder.py | 460 ++++++++++++++++++++- 1 file changed, 457 insertions(+), 3 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 9ab494be6b..c5368bca2a 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -17,7 +17,7 @@ import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, Any, List, Callable, TYPE_CHECKING import numpy as np import openvino @@ -29,6 +29,11 @@ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.configuration_utils import GenerationConfig, GenerationMode +from transformers.generation.utils import GenerateOutput, GenerateBeamDecoderOnlyOutput, GenerateBeamOutput, _split_model_inputs, stack_model_outputs +from transformers.generation.stopping_criteria import StoppingCriteriaList, EosTokenCriteria +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.beam_search import BeamScorer from optimum.utils.normalized_config import NormalizedConfigManager @@ -40,6 +45,10 @@ from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from transformers.streamers import BaseStreamer + logger = logging.getLogger(__name__) @@ -519,8 +528,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - - return { + + model_inputs = { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache, @@ -528,6 +537,88 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "attention_mask": attention_mask, } + return model_inputs + + @staticmethod + def _fake_expand_inputs_for_generation( + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + """Expands tensors from [batch_size, ...] 
to [batch_size * expand_size, ...]""" + + # postpone expanding inputs for next iterations + + return input_ids, model_kwargs + + + def _expand_outputs_for_generation( + self, + expand_size: int, + logits: torch.Tensor, + past_key_values: Tuple + ): + if expand_size != 1: + logits = logits.repeat_interleave(expand_size, dim=0) + if past_key_values and not self.stateful: + past_key_values = tuple( + tuple(past_state.repeat(expand_size, axis=0 if not self.config.model_type == "chatglm" else 1) for past_state in layer_past) for layer_past in past_key_values + ) + if self.stateful: + self.next_beam_idx = self.next_beam_idx.repeat(expand_size) + return logits, past_key_values + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + self._orig_expand_inputs_for_generation = self._expand_inputs_for_generation + self._expand_inputs_for_generation = self._fake_expand_inputs_for_generation + result = super().generate( + inputs, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + assistant_model, + streamer, + negative_prompt_ids, + negative_prompt_attention_mask, + **kwargs) + self._expand_inputs_for_generation = self._orig_expand_inputs_for_generation + return result + + + def _update_inputs_for_beam_search(self, model_inputs: Dict[str, torch.Tensor], input_expand_size: int): + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if dict_to_expand[key] is not None: + if isinstance(dict_to_expand[key], torch.Tensor): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(input_expand_size, dim=0) + if isinstance(dict_to_expand[key], np.ndarray): + dict_to_expand[key] = dict_to_expand[key].repeat(input_expand_size, dim=0) + return dict_to_expand + + input_ids = model_inputs.pop("input_ids", None) + model_inputs = _expand_dict_for_generation(model_inputs) + if input_ids is not None: + model_inputs["input_ids"] = input_ids + return model_inputs + + def _get_past_length(self, past_key_values=None): if past_key_values is None: return 0 @@ -663,6 +754,369 @@ def _from_pretrained( return causal_model + def _beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + sequential: Optional[bool] = None, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for 
text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate() + instead. For an overview of generation strategies and code examples, check the [following + guide](../generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + sequential (`bool`, defaults to `False`): + By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for + more details). This flag will avoid parallelizing the beam search and will instead run beam search + sequentially. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
+ + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ] + ... ) + + >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Wie alt bist du?'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + sequential = sequential if sequential is not None else self.generation_config.low_memory + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." 
+ " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = ( + tuple(() for _ in range(num_beams * batch_size)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + first_iteration = True + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # if sequential is True, split the input to batches of batch_size and run sequentially + if sequential: + inputs_per_sub_batches = _split_model_inputs( + model_inputs, split_size=batch_size, full_batch_size=batch_beam_size if not self._first_iteration else batch_size + ) + outputs_per_sub_batch = [ + self( + **inputs_per_sub_batch, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + for inputs_per_sub_batch in inputs_per_sub_batches + ] + + outputs = stack_model_outputs(outputs_per_sub_batch) + + else: # Unchanged original behavior + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + if first_iteration: + input_ids = input_ids.repeat_interleave(num_beams, dim=0) + model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) + logits, past_key_values = self._expand_outputs_for_generation(num_beams, outputs.logits, outputs.past_key_values) + outputs.logits = logits + outputs.past_key_values = past_key_values + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + next_token_scores = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores_processed,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
+ n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if model_kwargs.get("past_key_values", None) is not None and not first_iteration: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + first_iteration = False + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] class OVBloomForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation From f263f3f9388ff8c682a35b3cbbd4aad6fc0592ce Mon Sep 17 00:00:00 2001 From: eaidova Date: Tue, 30 Apr 2024 09:25:50 +0400 Subject: [PATCH 2/8] other beam search algos --- optimum/intel/openvino/modeling_decoder.py | 946 +++++++++++++++++---- 1 file changed, 776 insertions(+), 170 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index c5368bca2a..4cd9bb8303 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -17,7 +17,7 @@ import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, Optional, Tuple, Union, Any, List, Callable, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import openvino @@ -28,12 +28,22 @@ from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin -from transformers.modeling_outputs import CausalLMOutputWithPast 
+from transformers.generation.beam_search import BeamScorer, ConstrainedBeamSearchScorer from transformers.generation.configuration_utils import GenerationConfig, GenerationMode -from transformers.generation.utils import GenerateOutput, GenerateBeamDecoderOnlyOutput, GenerateBeamOutput, _split_model_inputs, stack_model_outputs -from transformers.generation.stopping_criteria import StoppingCriteriaList, EosTokenCriteria from transformers.generation.logits_process import LogitsProcessorList -from transformers.generation.beam_search import BeamScorer +from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) +from transformers.generation.utils import ( + GenerateBeamDecoderOnlyOutput, + GenerateBeamOutput, + GenerateOutput, + _split_model_inputs, + stack_model_outputs, +) +from transformers.modeling_outputs import CausalLMOutputWithPast from optimum.utils.normalized_config import NormalizedConfigManager @@ -45,6 +55,7 @@ from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE + if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel from transformers.streamers import BaseStreamer @@ -528,7 +539,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - + model_inputs = { "input_ids": input_ids, "past_key_values": past_key_values, @@ -552,18 +563,16 @@ def _fake_expand_inputs_for_generation( return input_ids, model_kwargs - - def _expand_outputs_for_generation( - self, - expand_size: int, - logits: torch.Tensor, - past_key_values: Tuple - ): + def _expand_outputs_for_generation(self, expand_size: int, logits: torch.Tensor, past_key_values: Tuple): if expand_size != 1: logits = logits.repeat_interleave(expand_size, dim=0) if past_key_values and not self.stateful: past_key_values = tuple( - tuple(past_state.repeat(expand_size, axis=0 if not self.config.model_type == "chatglm" else 1) for past_state in layer_past) for layer_past in past_key_values + tuple( + past_state.repeat(expand_size, axis=0 if not self.config.model_type == "chatglm" else 1) + for past_state in layer_past + ) + for layer_past in past_key_values ) if self.stateful: self.next_beam_idx = self.next_beam_idx.repeat(expand_size) @@ -584,23 +593,36 @@ def generate( negative_prompt_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: - self._orig_expand_inputs_for_generation = self._expand_inputs_for_generation - self._expand_inputs_for_generation = self._fake_expand_inputs_for_generation - result = super().generate( - inputs, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - assistant_model, - streamer, - negative_prompt_ids, - negative_prompt_attention_mask, - **kwargs) - self._expand_inputs_for_generation = self._orig_expand_inputs_for_generation - return result + _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) + generation_mode = _generation_config.get_generation_mode(assistant_model) + + is_beam_search = generation_mode in [ + GenerationMode.BEAM_SEARCH, + GenerationMode.BEAM_SAMPLE, + GenerationMode.GROUP_BEAM_SEARCH, + GenerationMode.CONSTRAINED_BEAM_SEARCH, + ] + + if is_beam_search: + self._orig_expand_inputs_for_generation = 
self._expand_inputs_for_generation + self._expand_inputs_for_generation = self._fake_expand_inputs_for_generation + result = super().generate( + inputs, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + assistant_model, + streamer, + negative_prompt_ids, + negative_prompt_attention_mask, + **kwargs, + ) + if is_beam_search: + self._expand_inputs_for_generation = self._orig_expand_inputs_for_generation + return result def _update_inputs_for_beam_search(self, model_inputs: Dict[str, torch.Tensor], input_expand_size: int): def _expand_dict_for_generation(dict_to_expand): @@ -611,13 +633,12 @@ def _expand_dict_for_generation(dict_to_expand): if isinstance(dict_to_expand[key], np.ndarray): dict_to_expand[key] = dict_to_expand[key].repeat(input_expand_size, dim=0) return dict_to_expand - + input_ids = model_inputs.pop("input_ids", None) model_inputs = _expand_dict_for_generation(model_inputs) if input_ids is not None: model_inputs["input_ids"] = input_ids return model_inputs - def _get_past_length(self, past_key_values=None): if past_key_values is None: @@ -772,120 +793,6 @@ def _beam_search( sequential: Optional[bool] = None, **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search decoding** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - - In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate() - instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_logits (`bool`, *optional*, defaults to `False`): - Whether or not to return the raw prediction logit scores. 
See `logits` under returned tensors for - more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - sequential (`bool`, defaults to `False`): - By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for - more details). This flag will avoid parallelizing the beam search and will instead run beam search - sequentially. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... 
) - - >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() sequential = sequential if sequential is not None else self.generation_config.low_memory @@ -908,8 +815,6 @@ def _beam_search( ) stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: - # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever eos_token_id = [ criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") ] @@ -952,13 +857,6 @@ def _beam_search( cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) @@ -976,7 +874,9 @@ def _beam_search( # if sequential is True, split the input to batches of batch_size and run sequentially if sequential: inputs_per_sub_batches = _split_model_inputs( - model_inputs, split_size=batch_size, full_batch_size=batch_beam_size if not self._first_iteration else batch_size + model_inputs, + split_size=batch_size, + full_batch_size=batch_beam_size if not self._first_iteration else batch_size, ) outputs_per_sub_batch = [ self( @@ -1000,14 +900,12 @@ def _beam_search( if first_iteration: input_ids = input_ids.repeat_interleave(num_beams, dim=0) model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) - logits, past_key_values = self._expand_outputs_for_generation(num_beams, outputs.logits, outputs.past_key_values) + logits, past_key_values = self._expand_outputs_for_generation( + num_beams, outputs.logits, outputs.past_key_values + ) outputs.logits = logits outputs.past_key_values = past_key_values - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] next_token_scores = torch.nn.functional.log_softmax( next_token_logits, dim=-1 @@ -1086,7 +984,227 @@ def _beam_search( if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): this_peer_finished = True - + + first_iteration = False + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] 
= None + + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + def _beam_sample( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." 
+ " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # TODO remove when the method is totally private and beam scorer refactored + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = ( + tuple(() for _ in range(batch_size * num_beams)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + first_iteration = True + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if first_iteration: + input_ids = input_ids.repeat_interleave(num_beams, dim=0) + model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) + logits, past_key_values = self._expand_outputs_for_generation( + num_beams, outputs.logits, outputs.past_key_values + ) + outputs.logits = logits + outputs.past_key_values = past_key_values + + next_token_logits = outputs.logits[:, -1, :] + + next_token_scores = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, 
next_token_scores) + next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores_processed,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + probs = torch.nn.functional.softmax(next_token_scores, dim=-1) + + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if model_kwargs.get("past_key_values", None) is not None and not first_iteration: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True first_iteration = False sequence_outputs = beam_scorer.finalize( @@ -1106,18 +1224,506 @@ def _beam_search( sequence_outputs["sequence_scores"] = None return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return sequence_outputs["sequences"] + def _group_beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + 
logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs, + ): + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups + device = input_ids.device + + _, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + + if return_dict_in_generate and output_scores: + beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] + else: + beam_indices = None + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = 
() if (return_dict_in_generate and output_hidden_states) else None + + # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + first_iteration = True + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if first_iteration: + input_ids = input_ids.repeat_interleave(num_beams, dim=0) + model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) + logits, past_key_values = self._expand_outputs_for_generation( + num_beams, outputs.logits, outputs.past_key_values + ) + outputs.logits = logits + outputs.past_key_values = past_key_values + + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + if output_logits: + raw_logit_score = outputs.logits[:, -1, :] + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits = outputs.logits[batch_group_indices, -1, :] + + next_token_scores = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + if output_scores: + processed_score[batch_group_indices] = next_token_scores_processed + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
+ n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + group_index=beam_group_idx, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + if return_dict_in_generate and output_scores: + beam_indices[beam_group_idx] = tuple( + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + ) + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + + group_start_idx + + (beam_idx % group_size) + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (processed_score,) + if output_logits: + raw_logits += (raw_logit_score,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if model_kwargs.get("past_key_values", None) is not None and not first_iteration: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], reordering_indices + ) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + first_iteration = False + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return 
sequence_outputs["sequences"] + + def _constrained_beam_search( + self, + input_ids: torch.LongTensor, + constrained_beam_scorer: ConstrainedBeamSearchScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = None, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + if eos_token_id is not None: + logger.warning_once( + "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" + " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." + " Otherwise make sure to set `model.generation_config.eos_token_id`", + FutureWarning, + ) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + eos_token_id = [ + criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + batch_size = len(constrained_beam_scorer._beam_hyps) + num_beams = constrained_beam_scorer.num_beams + + _, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = ( + tuple(() for _ in range(batch_size * 
num_beams)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + + first_iteration = True + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if first_iteration: + input_ids = input_ids.repeat_interleave(num_beams, dim=0) + model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) + logits, past_key_values = self._expand_outputs_for_generation( + num_beams, outputs.logits, outputs.past_key_values + ) + outputs.logits = logits + outputs.past_key_values = past_key_values + + next_token_logits = outputs.logits[:, -1, :] + next_token_scores = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) + + scores_for_all_vocab = next_token_scores.clone() + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
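            # Illustrative note (not part of the original patch): with num_beams = 4 and a single
            # eos token id, n_eos_tokens = 1, so the topk call below keeps max(2, 1 + 1) * 4 = 8
            # candidates per batch entry; even if every beam's top candidate were eos, each beam
            # would still have a non-eos continuation available for the next step.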
+ n_eos_tokens = len(eos_token_id) if eos_token_id else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = (next_tokens / vocab_size).long() + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = constrained_beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + scores_for_all_vocab, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if model_kwargs.get("past_key_values", None) is not None and not first_iteration: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + first_iteration = False + + sequence_outputs = constrained_beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + class OVBloomForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): From df8a5c647c23df72776fe94245e92944d55ea178 Mon Sep 17 00:00:00 2001 From: eaidova Date: Tue, 30 Apr 2024 10:37:06 +0400 Subject: [PATCH 3/8] add test --- optimum/intel/openvino/modeling_decoder.py | 28 ++++---- tests/openvino/test_modeling.py | 81 ++++++++++++++++++++++ 2 files changed, 95 insertions(+), 14 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 4cd9bb8303..6d8a9106ad 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -1418,22 +1418,22 @@ def _group_beam_search( beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] - if return_dict_in_generate and output_scores: - beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) - ) + if return_dict_in_generate and output_scores: + beam_indices[beam_group_idx] = tuple( + 
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + ) - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - current_tokens[batch_group_indices] = group_input_ids[:, -1] + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") - + group_start_idx - + (beam_idx % group_size) - ) + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + + group_start_idx + + (beam_idx % group_size) + ) # Store scores, attentions and hidden_states when required if return_dict_in_generate: diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index d4f55c683b..72016d0840 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -778,6 +778,87 @@ def test_default_filling_attention_mask_and_position_ids(self): del model_with_cache gc.collect() + def test_beam_search(self): + model_id = MODEL_NAMES["llama"] + ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True) + ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + ov_model_stateful.generation_config.eos_token_id = None + ov_model_stateless.generation_config.eos_token_id = None + transformers_model.generation_config.eos_token_id = None + ov_model_stateful.config.eos_token_id = None + ov_model_stateless.config.eos_token_id = None + transformers_model.config.eos_token_id = None + + # beam search + gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + ) + + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) + # beam sample + gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=True, + eos_token_id=None, + top_k=1, + ) + + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateless_outputs, 
transformers_outputs)) + + # group beam search + gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + num_beam_groups=2, + diversity_penalty=0.0000001, + ) + + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) + + # constrained beam search + force_word = "cat" + force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids] + gen_config = GenerationConfig( + max_new_tokens=10, + min_new_tokens=10, + num_beams=4, + do_sample=False, + eos_token_id=None, + force_words_ids=force_words_ids, + ) + + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) + class OVModelForMaskedLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( From f82ff061bded80acd7781596729ce15788989bb6 Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 13 May 2024 11:04:53 +0400 Subject: [PATCH 4/8] do not touch decoding cycles --- optimum/intel/openvino/modeling_decoder.py | 1040 +------------------- 1 file changed, 48 insertions(+), 992 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 6d8a9106ad..cf7cf3a6fa 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -142,6 +142,8 @@ def __init__( self._pkv_precision = Type.f32 self.next_beam_idx = None self._past_length = 0 + self._first_iter_beam_search = False + self._second_iter_beam_search = False self.update_pkv_precision() if self.is_dynamic: self.model = self._reshape(self.model, -1, -1) @@ -389,6 +391,7 @@ def prepare_inputs( **kwargs, ) -> Dict: batch_size = input_ids.shape[0] + duplication_indices = None if self.config.model_type == "bloom": batch_size *= self.config.num_attention_heads @@ -438,7 +441,6 @@ def prepare_inputs( self.next_beam_idx = np.arange(batch_size, dtype=int) self._past_length = 0 past_len = self._get_past_length(past_key_values) - inputs["input_ids"] = np.array(input_ids) # Add the attention_mask inputs when needed if "attention_mask" in self.input_names or "position_ids" in self.input_names: @@ -468,7 +470,9 @@ def prepare_inputs( self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) ) - return inputs + if self._first_iter_beam_search: + inputs, duplication_indices = self._deduplicate_inputs(inputs) + return inputs, duplication_indices def forward( self, @@ -480,7 +484,7 @@ def forward( ) -> CausalLMOutputWithPast: self.compile() - inputs = self.prepare_inputs( + inputs, duplication_idicies = self.prepare_inputs( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, @@ -488,6 +492,8 @@ def forward( **kwargs, ) + print(inputs["input_ids"].shape) + # Run inference self.request.start_async(inputs, share_inputs=True) 
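        # wait() below blocks until the asynchronous inference issued by start_async() has finished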
self.request.wait() @@ -511,6 +517,10 @@ def forward( else: past_key_values = None + if self._first_iter_beam_search: + logits, past_key_values = self._expand_outputs_for_generation(duplication_idicies, logits, past_key_values) + self._first_iter_beam_search = False + return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation @@ -550,34 +560,46 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg return model_inputs - @staticmethod - def _fake_expand_inputs_for_generation( - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, - **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: - """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" - - # postpone expanding inputs for next iterations - - return input_ids, model_kwargs - - def _expand_outputs_for_generation(self, expand_size: int, logits: torch.Tensor, past_key_values: Tuple): - if expand_size != 1: - logits = logits.repeat_interleave(expand_size, dim=0) + def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): + if indicies.shape[0] != 1: + logits = logits[indicies] if past_key_values and not self.stateful: past_key_values = tuple( tuple( - past_state.repeat(expand_size, axis=0 if not self.config.model_type == "chatglm" else 1) + past_state[indicies] + if not self.config.model_type == "chatglm" + else past_state[:, indicies, ...] for past_state in layer_past ) for layer_past in past_key_values ) if self.stateful: - self.next_beam_idx = self.next_beam_idx.repeat(expand_size) + self.next_beam_idx = self.next_beam_idx[indicies] + self._second_iter_beam_search = True return logits, past_key_values + def _deduplicate_inputs(self, model_inputs: Dict): + input_ids = model_inputs["input_ids"] + upd_model_inputs = {} + unique_input_ids, indicies, reverse_indicies = np.unique( + input_ids, axis=0, return_index=True, return_inverse=True + ) + for input_name, input_tensor in model_inputs.items(): + if input_name not in ["input_ids", "beam_idx"]: + if not isinstance(input_tensor, Tensor): + upd_model_inputs[input_name] = input_tensor[indicies] + else: + shape = input_tensor.shape + dtype = input_tensor.element_type + shape[0 if not self.config.model_type == "chatglm" else 1] = indicies.shape[0] + upd_model_inputs[input_name] = Tensor(dtype, shape) + print(f"{input_name}: {upd_model_inputs[input_name].shape}") + upd_model_inputs["input_ids"] = unique_input_ids + if "beam_idx" in model_inputs: + beam_idx = np.arange(unique_input_ids.shape[0], dtype=int) + upd_model_inputs["beam_idx"] = beam_idx + return upd_model_inputs, reverse_indicies + @torch.no_grad() def generate( self, @@ -602,10 +624,8 @@ def generate( GenerationMode.GROUP_BEAM_SEARCH, GenerationMode.CONSTRAINED_BEAM_SEARCH, ] - if is_beam_search: - self._orig_expand_inputs_for_generation = self._expand_inputs_for_generation - self._expand_inputs_for_generation = self._fake_expand_inputs_for_generation + self._first_iter_beam_search = True result = super().generate( inputs, generation_config, @@ -619,27 +639,8 @@ def generate( negative_prompt_attention_mask, **kwargs, ) - - if is_beam_search: - self._expand_inputs_for_generation = self._orig_expand_inputs_for_generation return result - def _update_inputs_for_beam_search(self, model_inputs: Dict[str, torch.Tensor], input_expand_size: int): - def 
_expand_dict_for_generation(dict_to_expand): - for key in dict_to_expand: - if dict_to_expand[key] is not None: - if isinstance(dict_to_expand[key], torch.Tensor): - dict_to_expand[key] = dict_to_expand[key].repeat_interleave(input_expand_size, dim=0) - if isinstance(dict_to_expand[key], np.ndarray): - dict_to_expand[key] = dict_to_expand[key].repeat(input_expand_size, dim=0) - return dict_to_expand - - input_ids = model_inputs.pop("input_ids", None) - model_inputs = _expand_dict_for_generation(model_inputs) - if input_ids is not None: - model_inputs["input_ids"] = input_ids - return model_inputs - def _get_past_length(self, past_key_values=None): if past_key_values is None: return 0 @@ -670,7 +671,10 @@ def _reorder_cache( if self.stateful: # TODO: Apply it differently based on model type # TODO: At least for bloom we need to replicate values for each attention head - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = ( + np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx + ) # save beam_idx to be used as an input in the next iteration + self._second_iter_beam_search = False return past_key_values else: return tuple( @@ -775,954 +779,6 @@ def _from_pretrained( return causal_model - def _beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - sequential: Optional[bool] = None, - **model_kwargs, - ) -> Union[GenerateBeamOutput, torch.LongTensor]: - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - sequential = sequential if sequential is not None else self.generation_config.low_memory - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." 
- " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - eos_token_id = [ - criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - beam_indices = ( - tuple(() for _ in range(num_beams * batch_size)) if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens - # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
- beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - first_iteration = True - - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # if sequential is True, split the input to batches of batch_size and run sequentially - if sequential: - inputs_per_sub_batches = _split_model_inputs( - model_inputs, - split_size=batch_size, - full_batch_size=batch_beam_size if not self._first_iteration else batch_size, - ) - outputs_per_sub_batch = [ - self( - **inputs_per_sub_batch, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - for inputs_per_sub_batch in inputs_per_sub_batches - ] - - outputs = stack_model_outputs(outputs_per_sub_batch) - - else: # Unchanged original behavior - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - if first_iteration: - input_ids = input_ids.repeat_interleave(num_beams, dim=0) - model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) - logits, past_key_values = self._expand_outputs_for_generation( - num_beams, outputs.logits, outputs.past_key_values - ) - outputs.logits = logits - outputs.past_key_values = past_key_values - - next_token_logits = outputs.logits[:, -1, :] - next_token_scores = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores_processed,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - - # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0 - next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if model_kwargs.get("past_key_values", None) is not None and not first_iteration: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - - first_iteration = False - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return sequence_outputs["sequences"] - - def _beam_sample( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - **model_kwargs, - ) -> Union[GenerateBeamOutput, torch.LongTensor]: - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else 
self.generation_config.pad_token_id - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # TODO remove when the method is totally private and beam scorer refactored - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - beam_indices = ( - tuple(() for _ in range(batch_size * num_beams)) if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - first_iteration = True - - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if first_iteration: - input_ids = input_ids.repeat_interleave(num_beams, dim=0) - model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) - logits, past_key_values = self._expand_outputs_for_generation( - num_beams, outputs.logits, outputs.past_key_values - ) - outputs.logits = logits - 
outputs.past_key_values = past_key_values - - next_token_logits = outputs.logits[:, -1, :] - - next_token_scores = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores_processed,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - - next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) - next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - - next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) - next_tokens = torch.gather(next_tokens, -1, _indices) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if model_kwargs.get("past_key_values", None) is not None and not first_iteration: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - first_iteration = False - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - 
hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return sequence_outputs["sequences"] - - def _group_beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - **model_kwargs, - ): - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - eos_token_id = [ - criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - num_beams = beam_scorer.num_beams - num_beam_groups = beam_scorer.num_beam_groups - num_sub_beams = num_beams // num_beam_groups - batch_size = len(beam_scorer._beam_hyps) // num_beam_groups - device = input_ids.device - - _, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - - if return_dict_in_generate and output_scores: - beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] - else: - beam_indices = None - - # init attention / hidden states / scores tuples - scores = () if 
(return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in - # the same group don't produce same tokens everytime. - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) - beam_scores[:, ::num_sub_beams] = 0 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - first_iteration = True - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - # predicted tokens in cur_len step - current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) - - # indices which will form the beams in the next time step - reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) - - # do one decoder step on all beams of all sentences in batch - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if first_iteration: - input_ids = input_ids.repeat_interleave(num_beams, dim=0) - model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) - logits, past_key_values = self._expand_outputs_for_generation( - num_beams, outputs.logits, outputs.past_key_values - ) - outputs.logits = logits - outputs.past_key_values = past_key_values - - if output_scores: - processed_score = torch.zeros_like(outputs.logits[:, -1, :]) - if output_logits: - raw_logit_score = outputs.logits[:, -1, :] - - for beam_group_idx in range(num_beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - - for batch_idx in range(batch_size): - batch_group_indices.extend( - [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] - ) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of current group only - next_token_logits = outputs.logits[batch_group_indices, -1, :] - - next_token_scores = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * group_size, vocab_size) - vocab_size = next_token_scores.shape[-1] - - next_token_scores_processed = logits_processor( - group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx - ) - next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as(next_token_scores_processed) - - if output_scores: - processed_score[batch_group_indices] = next_token_scores_processed - - # reshape for beam search - next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) - - # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0 - next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=process_beam_indices, - group_index=beam_group_idx, - decoder_prompt_len=decoder_prompt_len, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - if return_dict_in_generate and output_scores: - beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) - ) - - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - current_tokens[batch_group_indices] = group_input_ids[:, -1] - - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") - + group_start_idx - + (beam_idx % group_size) - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (processed_score,) - if output_logits: - raw_logits += (raw_logit_score,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if model_kwargs.get("past_key_values", None) is not None and not first_iteration: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], reordering_indices - ) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - - first_iteration = False - - final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=final_beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return 
sequence_outputs["sequences"] - - def _constrained_beam_search( - self, - input_ids: torch.LongTensor, - constrained_beam_scorer: ConstrainedBeamSearchScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = None, - **model_kwargs, - ) -> Union[GenerateBeamOutput, torch.LongTensor]: - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - if eos_token_id is not None: - logger.warning_once( - "`eos_token_id` is deprecated in this function and will be removed in v4.41, use" - " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead." - " Otherwise make sure to set `model.generation_config.eos_token_id`", - FutureWarning, - ) - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - eos_token_id = [ - criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_logits = output_logits if output_logits is not None else self.generation_config.output_logits - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - batch_size = len(constrained_beam_scorer._beam_hyps) - num_beams = constrained_beam_scorer.num_beams - - _, cur_len = input_ids.shape - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - raw_logits = () if (return_dict_in_generate and output_logits) else None - beam_indices = ( - tuple(() for _ in range(batch_size * 
num_beams)) if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens - # of the first beam are considered to avoid sampling the exact same tokens across all beams. - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False - - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - - first_iteration = True - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if first_iteration: - input_ids = input_ids.repeat_interleave(num_beams, dim=0) - model_kwargs = self._update_inputs_for_beam_search(model_kwargs, num_beams) - logits, past_key_values = self._expand_outputs_for_generation( - num_beams, outputs.logits, outputs.past_key_values - ) - outputs.logits = logits - outputs.past_key_values = past_key_values - - next_token_logits = outputs.logits[:, -1, :] - next_token_scores = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( - next_token_scores_processed - ) - - scores_for_all_vocab = next_token_scores.clone() - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - - # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. 
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0 - next_token_scores, next_tokens = torch.topk( - next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = (next_tokens / vocab_size).long() - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = constrained_beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - scores_for_all_vocab, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - if model_kwargs.get("past_key_values", None) is not None and not first_iteration: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): - this_peer_finished = True - - first_iteration = False - - sequence_outputs = constrained_beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - decoder_prompt_len=decoder_prompt_len, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - return GenerateBeamDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - logits=raw_logits, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return sequence_outputs["sequences"] - class OVBloomForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation From 0dbb104c54633044d24cafa1316a5545fa1a989e Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 13 May 2024 16:26:37 +0400 Subject: [PATCH 5/8] fix stateless model support --- optimum/intel/openvino/modeling_decoder.py | 121 ++++++++++++++------- tests/openvino/test_modeling.py | 31 ++++-- 2 files changed, 105 insertions(+), 47 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index cf7cf3a6fa..1c9a0eb941 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy +import warnings import logging import os -import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import numpy as np import openvino @@ -28,21 +28,10 @@ from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin -from transformers.generation.beam_search import BeamScorer, ConstrainedBeamSearchScorer from transformers.generation.configuration_utils import GenerationConfig, GenerationMode from transformers.generation.logits_process import LogitsProcessorList -from transformers.generation.stopping_criteria import ( - EosTokenCriteria, - StoppingCriteriaList, - validate_stopping_criteria, -) -from transformers.generation.utils import ( - GenerateBeamDecoderOnlyOutput, - GenerateBeamOutput, - GenerateOutput, - _split_model_inputs, - stack_model_outputs, -) +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.generation.utils import GenerateOutput from transformers.modeling_outputs import CausalLMOutputWithPast from optimum.utils.normalized_config import NormalizedConfigManager @@ -398,7 +387,11 @@ def prepare_inputs( inputs = {} if not self.stateful: if past_key_values is not None: - if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + if ( + self.config.model_type not in MULTI_QUERY_ATTN_MODELS + or self.config.model_type == "falcon" + and self.config.new_decoder_architecture + ): if self._pkv_precision == Type.bf16: # numpy does not support bf16, pretending f16, should change to bf16 past_key_values = tuple( @@ -491,9 +484,6 @@ def forward( position_ids=position_ids, **kwargs, ) - - print(inputs["input_ids"].shape) - # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() @@ -509,7 +499,11 @@ def forward( if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) - if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + if ( + self.config.model_type not in MULTI_QUERY_ATTN_MODELS + or self.config.model_type == "falcon" + and self.config.new_decoder_architecture + ): # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) past_key_values = tuple( past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) @@ -561,21 +555,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg return model_inputs def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): + batch_size = logits.shape[0] if indicies.shape[0] != 1: logits = logits[indicies] if past_key_values and not self.stateful: - past_key_values = tuple( - tuple( - past_state[indicies] - if not self.config.model_type == "chatglm" - else past_state[:, indicies, ...] - for past_state in layer_past + if ( + self.config.model_type not in MULTI_QUERY_ATTN_MODELS + or self.config.model_type == "falcon" + and self.config.new_decoder_architecture + ): + past_key_values = tuple( + tuple( + past_state[indicies] + if not self.config.model_type == "chatglm" + else past_state[:, indicies, ...] 
+ for past_state in layer_past + ) + for layer_past in past_key_values ) - for layer_past in past_key_values - ) - if self.stateful: - self.next_beam_idx = self.next_beam_idx[indicies] - self._second_iter_beam_search = True + else: + past_key_values = tuple([past_state[indicies] for past_state in past_key_values]) + if self.stateful: + self.next_beam_idx = ( + self.next_beam_idx[indicies] + if self.next_beam_idx is not None + else np.arange(batch_size, dtype=int)[indicies] + ) + self._second_iter_beam_search = True return logits, past_key_values def _deduplicate_inputs(self, model_inputs: Dict): @@ -591,12 +597,19 @@ def _deduplicate_inputs(self, model_inputs: Dict): else: shape = input_tensor.shape dtype = input_tensor.element_type - shape[0 if not self.config.model_type == "chatglm" else 1] = indicies.shape[0] + upd_batch_size = indicies.shape[0] + if self.config.model_type == "bloom": + upd_batch_size *= self.config.num_attention_heads + shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size upd_model_inputs[input_name] = Tensor(dtype, shape) - print(f"{input_name}: {upd_model_inputs[input_name].shape}") upd_model_inputs["input_ids"] = unique_input_ids if "beam_idx" in model_inputs: - beam_idx = np.arange(unique_input_ids.shape[0], dtype=int) + beam_range = ( + unique_input_ids.shape[0] + if self.config.model_type != "bloom" + else unique_input_ids.shape[0] * self.config.num_attention_heads + ) + beam_idx = np.arange(beam_range, dtype=int) upd_model_inputs["beam_idx"] = beam_idx return upd_model_inputs, reverse_indicies @@ -646,7 +659,9 @@ def _get_past_length(self, past_key_values=None): return 0 if self.stateful: return self._past_length - if self.config.model_type in MULTI_QUERY_ATTN_MODELS: + if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): return past_key_values[0].shape[-2] seq_length_dim = -2 if self.config.model_type == "chatglm": @@ -677,9 +692,14 @@ def _reorder_cache( self._second_iter_beam_search = False return past_key_values else: - return tuple( - tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values - ) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not ( + self.config.model_type == "falcon" and self.config.new_decoder_architecture + ): + return tuple( + tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) + for layer_past in past_key_values + ) + return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values) def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" @@ -800,11 +820,12 @@ def _reorder_cache( This is required to match `past_key_values` with the correct beam_idx at every generation step. 
""" if self.stateful: - beam_idx = np.array(beam_idx) batch_size = beam_idx.shape[0] + beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx indices = np.array(range(batch_size * self.config.num_attention_heads)) indices = indices.reshape([batch_size, self.config.num_attention_heads]) self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() + self._second_iter_beam_search = False return past_key_values else: standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) @@ -854,6 +875,24 @@ def _convert_to_standard_cache( for layer_past in past_key_value ) + def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): + batch_size = logits.shape[0] + if indicies.shape[0] != 1: + logits = logits[indicies] + if past_key_values and not self.stateful: + pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size) + pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard) + past_key_values = self._convert_to_bloom_cache(pkv) + + if self.stateful: + self.next_beam_idx = ( + self.next_beam_idx[indicies] + if self.next_beam_idx is not None + else np.arange(batch_size, dtype=int)[indicies] + ) + self._second_iter_beam_search = True + return logits, past_key_values + class OVGPTBigCodeForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache @@ -861,7 +900,9 @@ def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: if self.stateful: - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx + self._second_iter_beam_search = False return past_key_values else: return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 72016d0840..8458c6fd9c 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -778,14 +778,31 @@ def test_default_filling_attention_mask_and_position_ids(self): del model_with_cache gc.collect() - def test_beam_search(self): - model_id = MODEL_NAMES["llama"] - ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True) - ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @pytest.mark.run_slow + @slow + def test_beam_search(self, model_arch): + model_kwargs = {} + model_id = MODEL_NAMES[model_arch] + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search + if model_arch in ["qwen", "chatglm"]: + return - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.pad_token = tokenizer.eos_token + ov_model_stateful = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=True, stateful=True, **model_kwargs + ) + ov_model_stateless = 
OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=True, stateful=False, **model_kwargs + ) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + tokenizer.pad_token_id = tokenizer.eos_token_id tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) ov_model_stateful.generation_config.eos_token_id = None ov_model_stateless.generation_config.eos_token_id = None From 59c8c40fcf4b8666805890d8d21cd7ef090c7746 Mon Sep 17 00:00:00 2001 From: eaidova Date: Tue, 14 May 2024 12:44:12 +0400 Subject: [PATCH 6/8] fix quantization --- optimum/intel/openvino/modeling_decoder.py | 2 +- optimum/intel/openvino/quantization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 1c9a0eb941..1df33c21d3 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -import warnings import logging import os +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 45961a86ff..5f4aa6571b 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -688,7 +688,7 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConf nsamples = quantization_config.num_samples if quantization_config.num_samples else 128 calibration_dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples) calibration_dataset = prepare_dataset(calibration_dataset) - calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)) + calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)[0]) return calibration_dataset From daecdacab813e8dcdbf84127b5ee90247512af0f Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 15 May 2024 09:18:10 +0400 Subject: [PATCH 7/8] move inputs modification into forward --- optimum/intel/openvino/modeling_decoder.py | 12 ++++++------ optimum/intel/openvino/quantization.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 1df33c21d3..e4dc1ed784 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -380,7 +380,6 @@ def prepare_inputs( **kwargs, ) -> Dict: batch_size = input_ids.shape[0] - duplication_indices = None if self.config.model_type == "bloom": batch_size *= self.config.num_attention_heads @@ -463,9 +462,7 @@ def prepare_inputs( self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) ) - if self._first_iter_beam_search: - inputs, duplication_indices = self._deduplicate_inputs(inputs) - return inputs, duplication_indices + return inputs def forward( self, @@ -477,13 +474,16 @@ def forward( ) -> CausalLMOutputWithPast: self.compile() - inputs, duplication_idicies = self.prepare_inputs( + inputs = self.prepare_inputs( input_ids=input_ids, 
attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, **kwargs, ) + + if self._first_iter_beam_search: + inputs, duplication_indices = self._deduplicate_inputs(inputs) # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() @@ -512,7 +512,7 @@ def forward( past_key_values = None if self._first_iter_beam_search: - logits, past_key_values = self._expand_outputs_for_generation(duplication_idicies, logits, past_key_values) + logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values) self._first_iter_beam_search = False return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 5f4aa6571b..45961a86ff 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -688,7 +688,7 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConf nsamples = quantization_config.num_samples if quantization_config.num_samples else 128 calibration_dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples) calibration_dataset = prepare_dataset(calibration_dataset) - calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)[0]) + calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)) return calibration_dataset From b1fc04b8b9e0cc48aeac2e36b38110debbc6afd2 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 15 May 2024 09:25:09 +0400 Subject: [PATCH 8/8] refactor test --- tests/openvino/test_modeling.py | 81 +++++++++++++-------------------- 1 file changed, 32 insertions(+), 49 deletions(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 8458c6fd9c..75c95c1563 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -793,40 +793,15 @@ def test_beam_search(self, model_arch): if model_arch in ["qwen", "chatglm"]: return - ov_model_stateful = OVModelForCausalLM.from_pretrained( - model_id, export=True, use_cache=True, stateful=True, **model_kwargs - ) - ov_model_stateless = OVModelForCausalLM.from_pretrained( - model_id, export=True, use_cache=True, stateful=False, **model_kwargs - ) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) - tokenizer.pad_token_id = tokenizer.eos_token_id - tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) - ov_model_stateful.generation_config.eos_token_id = None - ov_model_stateless.generation_config.eos_token_id = None - transformers_model.generation_config.eos_token_id = None - ov_model_stateful.config.eos_token_id = None - ov_model_stateless.config.eos_token_id = None - transformers_model.config.eos_token_id = None - - # beam search - gen_config = GenerationConfig( + beam_search_gen_config = GenerationConfig( max_new_tokens=10, min_new_tokens=10, num_beams=4, do_sample=False, eos_token_id=None, ) - - transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) - ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) - ov_stateless_outputs = ov_model_stateless.generate(**tokens, 
generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) - # beam sample - gen_config = GenerationConfig( + beam_sample_gen_config = GenerationConfig( max_new_tokens=10, min_new_tokens=10, num_beams=4, @@ -835,14 +810,7 @@ def test_beam_search(self, model_arch): top_k=1, ) - transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) - ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) - ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) - - # group beam search - gen_config = GenerationConfig( + group_beam_search_gen_config = GenerationConfig( max_new_tokens=10, min_new_tokens=10, num_beams=4, @@ -851,17 +819,9 @@ def test_beam_search(self, model_arch): num_beam_groups=2, diversity_penalty=0.0000001, ) - - transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) - ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) - ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) - - # constrained beam search force_word = "cat" force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids] - gen_config = GenerationConfig( + constrained_beam_search_gen_config = GenerationConfig( max_new_tokens=10, min_new_tokens=10, num_beams=4, @@ -870,11 +830,34 @@ def test_beam_search(self, model_arch): force_words_ids=force_words_ids, ) - transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) - ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) - ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config) - self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs)) + gen_configs = [ + beam_search_gen_config, + beam_sample_gen_config, + group_beam_search_gen_config, + constrained_beam_search_gen_config, + ] + ov_model_stateful = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=True, stateful=True, **model_kwargs + ) + ov_model_stateless = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=True, stateful=False, **model_kwargs + ) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + tokenizer.pad_token_id = tokenizer.eos_token_id + tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + ov_model_stateful.generation_config.eos_token_id = None + ov_model_stateless.generation_config.eos_token_id = None + transformers_model.generation_config.eos_token_id = None + ov_model_stateful.config.eos_token_id = None + ov_model_stateless.config.eos_token_id = None + transformers_model.config.eos_token_id = None + + for gen_config in gen_configs: + transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config) + ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config) + self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs)) + 
ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))


 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
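
Note on the cache-layout checks added throughout modeling_decoder.py: the same three-clause condition is repeated in `prepare_inputs`, `forward`, `_expand_outputs_for_generation`, `_get_past_length` and `_reorder_cache`. The helper below is hypothetical and not part of the patch; it only spells out the operator precedence (`or` binds more loosely than `and`), under the assumption that new-decoder-architecture falcon is meant to follow the standard per-layer key/value layout:

def kv_cache_uses_per_layer_tuples(config, multi_query_attn_models) -> bool:
    # Standard models keep the past as a tuple of (key, value) pairs per layer.
    if config.model_type not in multi_query_attn_models:
        return True
    # Falcon with the new decoder architecture is the exception among the
    # multi-query-attention models and is treated like the standard layout.
    return config.model_type == "falcon" and getattr(config, "new_decoder_architecture", False)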
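
The first-iteration deduplication works in two steps: beam search expands the batch to `batch_size * num_beams` rows, so the prompt pass would otherwise run on `num_beams` identical copies of every input. `_deduplicate_inputs` collapses the repeated rows before that single prompt forward pass, and `_expand_outputs_for_generation` scatters the logits (and, for stateless models, the key/value cache) back to the expanded layout. A minimal illustration of that round trip, using plain torch tensors and made-up helper names rather than the exact optimum-intel code:

import torch

def deduplicate_rows(input_ids: torch.Tensor):
    # unique prompt rows plus, for every original row, the index of its unique row
    unique_ids, reverse_indices = torch.unique(input_ids, dim=0, return_inverse=True)
    return unique_ids, reverse_indices

def expand_logits(logits: torch.Tensor, reverse_indices: torch.Tensor) -> torch.Tensor:
    # map per-unique-prompt logits back to the batch_size * num_beams layout
    return logits[reverse_indices]

batch_size, num_beams, seq_len, vocab_size = 2, 4, 5, 11
prompts = torch.randint(0, vocab_size, (batch_size, seq_len))
beam_inputs = prompts.repeat_interleave(num_beams, dim=0)    # what beam search would normally feed
unique_ids, reverse_indices = deduplicate_rows(beam_inputs)  # duplicate rows collapse
logits = torch.randn(unique_ids.shape[0], vocab_size)        # stand-in for one forward pass
expanded = expand_logits(logits, reverse_indices)            # one row per beam again
assert expanded.shape == (batch_size * num_beams, vocab_size)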
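
For bloom-style models the stateful cache keeps a fused batch dimension of `batch_size * num_attention_heads`, which is why the patched `_reorder_cache` builds an index grid before applying `beam_idx`. A standalone sketch of that remapping (illustrative function name and shapes, not the class method itself):

import numpy as np

def remap_beam_idx(beam_idx: np.ndarray, num_attention_heads: int) -> np.ndarray:
    batch_size = beam_idx.shape[0]
    # one row of fused cache slots per sequence, one column per attention head
    indices = np.arange(batch_size * num_attention_heads).reshape(batch_size, num_attention_heads)
    # pick the rows of the selected beams, then flatten back to the fused layout
    return np.take(indices, beam_idx, axis=0).flatten()

print(remap_beam_idx(np.array([1, 0, 3, 2]), num_attention_heads=2))
# [2 3 0 1 6 7 4 5]: selecting beam 1 for the first sequence moves both of its head slots to the front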
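
The refactored test drives the stateful and stateless models through beam search, beam sample, group beam search and constrained beam search with the same `GenerationConfig` objects. For reference, a user-level sketch of that code path; the checkpoint name is only a placeholder, and any causal LM that exports to OpenVINO should behave the same way:

from transformers import AutoTokenizer, GenerationConfig
from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder checkpoint
model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id  # padding is required for batched prompts

tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
gen_config = GenerationConfig(max_new_tokens=10, min_new_tokens=10, num_beams=4, do_sample=False)
outputs = model.generate(**tokens, generation_config=gen_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))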