
Commit 2b902bb

Optimize first latency beam search for OVModelForCausalLM (#695)
* WIP: beam search only
* other beam search algos
* add test
* do not touch decoding cycles
* fix stateless model support
* fix quantization
* move inputs modification into forward
* refactor test
1 parent bfc8663 commit 2b902bb

2 files changed: +250 −12 lines changed

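As a rough usage sketch (not part of this commit; the checkpoint id and prompt are placeholders): any beam-search style `generate()` call on an `OVModelForCausalLM` now marks the first decoding iteration, so duplicated prompt rows are removed before the first inference and the outputs are expanded back to the full beam batch afterwards.

# Hypothetical usage sketch; "gpt2" and the prompt are placeholders, not from the commit.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Export the checkpoint to OpenVINO IR on the fly; stateful=True keeps the KV cache inside the model.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)

inputs = tokenizer("Today is a nice day and", return_tensors="pt")
# num_beams > 1 selects a beam-search generation mode, which triggers the optimized first iteration.
outputs = model.generate(**inputs, num_beams=4, do_sample=False, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))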

optimum/intel/openvino/modeling_decoder.py

+169 −12
@@ -17,7 +17,7 @@
 import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import openvino
@@ -28,6 +28,10 @@
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast

 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -41,6 +45,11 @@
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE


+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.streamers import BaseStreamer
+
+
 logger = logging.getLogger(__name__)

 core = Core()
@@ -122,6 +131,8 @@ def __init__(
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
         self._past_length = 0
+        self._first_iter_beam_search = False
+        self._second_iter_beam_search = False
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -375,7 +386,11 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -418,7 +433,6 @@ def prepare_inputs(
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
                     self._past_length = 0
         past_len = self._get_past_length(past_key_values)
-
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
         if "attention_mask" in self.input_names or "position_ids" in self.input_names:
@@ -468,6 +482,8 @@ def forward(
             **kwargs,
         )

+        if self._first_iter_beam_search:
+            inputs, duplication_indices = self._deduplicate_inputs(inputs)
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -483,14 +499,22 @@ def forward(
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if (
+                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                or self.config.model_type == "falcon"
+                and self.config.new_decoder_architecture
+            ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
                     past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                 )
         else:
             past_key_values = None

+        if self._first_iter_beam_search:
+            logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
+            self._first_iter_beam_search = False
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

     # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
@@ -520,20 +544,124 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        return {
+        model_inputs = {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": use_cache,
             "position_ids": position_ids,
             "attention_mask": attention_mask,
         }

+        return model_inputs
+
+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
+                    )
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+            self._second_iter_beam_search = True
+        return logits, past_key_values
+
+    def _deduplicate_inputs(self, model_inputs: Dict):
+        input_ids = model_inputs["input_ids"]
+        upd_model_inputs = {}
+        unique_input_ids, indicies, reverse_indicies = np.unique(
+            input_ids, axis=0, return_index=True, return_inverse=True
+        )
+        for input_name, input_tensor in model_inputs.items():
+            if input_name not in ["input_ids", "beam_idx"]:
+                if not isinstance(input_tensor, Tensor):
+                    upd_model_inputs[input_name] = input_tensor[indicies]
+                else:
+                    shape = input_tensor.shape
+                    dtype = input_tensor.element_type
+                    upd_batch_size = indicies.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
+                    upd_model_inputs[input_name] = Tensor(dtype, shape)
+        upd_model_inputs["input_ids"] = unique_input_ids
+        if "beam_idx" in model_inputs:
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
+            upd_model_inputs["beam_idx"] = beam_idx
+        return upd_model_inputs, reverse_indicies
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
+        generation_mode = _generation_config.get_generation_mode(assistant_model)
+
+        is_beam_search = generation_mode in [
+            GenerationMode.BEAM_SEARCH,
+            GenerationMode.BEAM_SAMPLE,
+            GenerationMode.GROUP_BEAM_SEARCH,
+            GenerationMode.CONSTRAINED_BEAM_SEARCH,
+        ]
+        if is_beam_search:
+            self._first_iter_beam_search = True
+        result = super().generate(
+            inputs,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            assistant_model,
+            streamer,
+            negative_prompt_ids,
+            negative_prompt_attention_mask,
+            **kwargs,
+        )
+        return result
+
     def _get_past_length(self, past_key_values=None):
         if past_key_values is None:
             return 0
         if self.stateful:
             return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
         if self.config.model_type == "chatglm":
@@ -558,12 +686,20 @@ def _reorder_cache(
         if self.stateful:
             # TODO: Apply it differently based on model type
             # TODO: At least for bloom we need to replicate values for each attention head
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = (
+                np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            )  # save beam_idx to be used as an input in the next iteration
+            self._second_iter_beam_search = False
             return past_key_values
         else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -684,11 +820,12 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
         if self.stateful:
-            beam_idx = np.array(beam_idx)
             batch_size = beam_idx.shape[0]
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
             self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
@@ -738,14 +875,34 @@ def _convert_to_standard_cache(
             for layer_past in past_key_value
         )

+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+            self._second_iter_beam_search = True
+        return logits, past_key_values
+

 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
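For intuition, here is a standalone sketch of the idea behind `_deduplicate_inputs` and `_expand_outputs_for_generation` (the `run_model` function is a toy stand-in, not the commit's code): on the first beam-search step the batch holds `num_beams` identical copies of each prompt, so inference only needs the unique rows, and the results can be expanded back with the inverse indices from `np.unique`.

# Toy sketch of first-iteration deduplication; run_model stands in for one OpenVINO inference call.
import numpy as np

def run_model(input_ids: np.ndarray) -> np.ndarray:
    # Pretend "logits": one value per batch row.
    return input_ids.sum(axis=1, keepdims=True).astype(np.float32)

num_beams = 4
prompt = np.array([[101, 7, 9]])
batched = np.repeat(prompt, num_beams, axis=0)  # what beam search sends on the first step

unique_ids, first_idx, reverse_idx = np.unique(batched, axis=0, return_index=True, return_inverse=True)
reverse_idx = reverse_idx.reshape(-1)  # guard against numpy versions returning a 2-D inverse
logits_unique = run_model(unique_ids)           # one inference row instead of num_beams
logits_full = logits_unique[reverse_idx]        # expand back to the beam batch

assert logits_full.shape[0] == num_beams
assert np.allclose(logits_full, run_model(batched))  # same result, less first-token compute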

tests/openvino/test_modeling.py

+81
@@ -778,6 +778,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
         del model_with_cache
         gc.collect()

+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
+    def test_beam_search(self, model_arch):
+        model_kwargs = {}
+        model_id = MODEL_NAMES[model_arch]
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {
+                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
+                "trust_remote_code": True,
+            }
+        # Qwen tokenizer does not support padding; the chatglm test model produces nan values that are incompatible with beam search
+        if model_arch in ["qwen", "chatglm"]:
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        beam_search_gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+        )
+        beam_sample_gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=True,
+            eos_token_id=None,
+            top_k=1,
+        )
+
+        group_beam_search_gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            num_beam_groups=2,
+            diversity_penalty=0.0000001,
+        )
+        force_word = "cat"
+        force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
+        constrained_beam_search_gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            force_words_ids=force_words_ids,
+        )
+
+        gen_configs = [
+            beam_search_gen_config,
+            beam_sample_gen_config,
+            group_beam_search_gen_config,
+            constrained_beam_search_gen_config,
+        ]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
+        )
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
+        )
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        for gen_config in gen_configs:
+            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+            ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+            ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+

 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
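To run only the new test locally, something like the snippet below should work; it assumes the repository's usual convention of gating `@slow` tests on the `RUN_SLOW` environment variable (an assumption, not stated in this diff).

# Hedged example: select the new test by name via pytest's Python API.
import os
import pytest

os.environ["RUN_SLOW"] = "1"  # assumed gate for the @slow-decorated test
raise SystemExit(pytest.main(["-k", "test_beam_search", "tests/openvino/test_modeling.py"]))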
