optimize first latency beam search for OVModelForCausalLM #695

Merged
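For orientation, below is a minimal, hypothetical usage sketch of the beam-search decoding path this change targets. The checkpoint name and prompt are placeholders; the from_pretrained and generate calls mirror the new test added in tests/openvino/test_modeling.py further down.

# Hypothetical sketch only; not part of this diff.
from transformers import AutoTokenizer, GenerationConfig
from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)

tokens = tokenizer(["Today is a nice day"], return_tensors="pt")
gen_config = GenerationConfig(max_new_tokens=10, num_beams=4, do_sample=False)
# With num_beams > 1, the first forward pass now runs once on the deduplicated prompt instead of
# on num_beams identical copies, which is where the first-token latency saving comes from.
output = model.generate(**tokens, generation_config=gen_config)
print(tokenizer.batch_decode(output, skip_special_tokens=True))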
181 changes: 169 additions & 12 deletions optimum/intel/openvino/modeling_decoder.py
@@ -17,7 +17,7 @@
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino
@@ -28,6 +28,10 @@
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.utils.normalized_config import NormalizedConfigManager
@@ -41,6 +45,11 @@
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE


if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.streamers import BaseStreamer


logger = logging.getLogger(__name__)

core = Core()
@@ -122,6 +131,8 @@ def __init__(
self._pkv_precision = Type.f32
self.next_beam_idx = None
self._past_length = 0
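# Flags for the beam-search first-token optimization: collapse identical beam rows on the first
# forward pass and keep the already-expanded beam index on the step that follows.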
self._first_iter_beam_search = False
self._second_iter_beam_search = False
self.update_pkv_precision()
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
@@ -375,7 +386,11 @@ def prepare_inputs(
inputs = {}
if not self.stateful:
if past_key_values is not None:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
):
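# Falcon with new_decoder_architecture keeps the standard per-layer (key, value) cache layout,
# so it takes the non-multi-query path here and in the analogous checks below.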
if self._pkv_precision == Type.bf16:
# numpy does not support bf16, pretending f16, should change to bf16
past_key_values = tuple(
@@ -418,7 +433,6 @@ def prepare_inputs(
self.next_beam_idx = np.arange(batch_size, dtype=int)
self._past_length = 0
past_len = self._get_past_length(past_key_values)

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names or "position_ids" in self.input_names:
@@ -468,6 +482,8 @@ def forward(
**kwargs,
)

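# On the first beam-search iteration every beam holds an identical copy of its prompt, so the
# duplicated rows are collapsed before inference and restored afterwards from duplication_indices.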
if self._first_iter_beam_search:
inputs, duplication_indices = self._deduplicate_inputs(inputs)
# Run inference
self.request.start_async(inputs, share_inputs=True)
self.request.wait()
@@ -483,14 +499,22 @@
if self.use_cache:
# Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
):
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
else:
past_key_values = None

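# Broadcast the deduplicated logits (and cache, for stateless models) back to the full beam batch
# so that beam scoring sees one row per beam again.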
if self._first_iter_beam_search:
logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
self._first_iter_beam_search = False

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

# Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
@@ -520,20 +544,124 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

return {
model_inputs = {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": position_ids,
"attention_mask": attention_mask,
}

return model_inputs

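# Re-expands outputs from the deduplicated batch to the original beam batch: logits (and, for
# stateless models, past_key_values) are indexed with the reverse indices, while stateful models
# only need next_beam_idx expanded because their KV cache lives inside the inference request.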
def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
):
past_key_values = tuple(
tuple(
past_state[indicies]
if not self.config.model_type == "chatglm"
else past_state[:, indicies, ...]
for past_state in layer_past
)
for layer_past in past_key_values
)
else:
past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
if self.stateful:
self.next_beam_idx = (
self.next_beam_idx[indicies]
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
return logits, past_key_values

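# Collapses identical prompt rows with np.unique, keeping the reverse indices needed to rebuild
# the full batch later; openvino Tensor inputs are replaced by empty tensors of the reduced batch
# size, and beam_idx becomes an arange over the deduplicated (for bloom, per-head) batch.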
def _deduplicate_inputs(self, model_inputs: Dict):
input_ids = model_inputs["input_ids"]
upd_model_inputs = {}
unique_input_ids, indicies, reverse_indicies = np.unique(
input_ids, axis=0, return_index=True, return_inverse=True
)
for input_name, input_tensor in model_inputs.items():
if input_name not in ["input_ids", "beam_idx"]:
if not isinstance(input_tensor, Tensor):
upd_model_inputs[input_name] = input_tensor[indicies]
else:
shape = input_tensor.shape
dtype = input_tensor.element_type
upd_batch_size = indicies.shape[0]
if self.config.model_type == "bloom":
upd_batch_size *= self.config.num_attention_heads
shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
upd_model_inputs[input_name] = Tensor(dtype, shape)
upd_model_inputs["input_ids"] = unique_input_ids
if "beam_idx" in model_inputs:
beam_range = (
unique_input_ids.shape[0]
if self.config.model_type != "bloom"
else unique_input_ids.shape[0] * self.config.num_attention_heads
)
beam_idx = np.arange(beam_range, dtype=int)
upd_model_inputs["beam_idx"] = beam_idx
return upd_model_inputs, reverse_indicies

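# generate() is overridden only to detect beam-style generation modes up front and set
# _first_iter_beam_search, so that forward() deduplicates the prompt on the first pass;
# everything else is delegated to the parent implementation.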
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
_generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
generation_mode = _generation_config.get_generation_mode(assistant_model)

is_beam_search = generation_mode in [
GenerationMode.BEAM_SEARCH,
GenerationMode.BEAM_SAMPLE,
GenerationMode.GROUP_BEAM_SEARCH,
GenerationMode.CONSTRAINED_BEAM_SEARCH,
]
if is_beam_search:
self._first_iter_beam_search = True
result = super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
assistant_model,
streamer,
negative_prompt_ids,
negative_prompt_attention_mask,
**kwargs,
)
return result

def _get_past_length(self, past_key_values=None):
if past_key_values is None:
return 0
if self.stateful:
return self._past_length
if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return past_key_values[0].shape[-2]
seq_length_dim = -2
if self.config.model_type == "chatglm":
@@ -558,12 +686,20 @@ def _reorder_cache(
if self.stateful:
# TODO: Apply it differently based on model type
# TODO: At least for bloom we need to replicate values for each attention head
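# On the step right after the deduplicated first pass, next_beam_idx has already been set by
# _expand_outputs_for_generation to map each beam onto its deduplicated cache row, so the
# scorer's beam_idx is ignored for that one step.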
self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration
self.next_beam_idx = (
np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
) # save beam_idx to be used as an input in the next iteration
self._second_iter_beam_search = False
return past_key_values
else:
return tuple(
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return tuple(
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
for layer_past in past_key_values
)
return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
Expand Down Expand Up @@ -684,11 +820,12 @@ def _reorder_cache(
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""
if self.stateful:
beam_idx = np.array(beam_idx)
batch_size = beam_idx.shape[0]
beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
indices = np.array(range(batch_size * self.config.num_attention_heads))
indices = indices.reshape([batch_size, self.config.num_attention_heads])
self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
self._second_iter_beam_search = False
return past_key_values
else:
standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
@@ -738,14 +875,34 @@ def _convert_to_standard_cache(
for layer_past in past_key_value
)

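# Bloom folds the attention heads into the batch dimension of its cache, so the cache is converted
# to the standard layout before row selection and back to the bloom layout afterwards.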
def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
past_key_values = self._convert_to_bloom_cache(pkv)

if self.stateful:
self.next_beam_idx = (
self.next_beam_idx[indicies]
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
return logits, past_key_values


class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
if self.stateful:
self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration
# save beam_idx to be used as an input in the next iteration
self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
self._second_iter_beam_search = False
return past_key_values
else:
return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
81 changes: 81 additions & 0 deletions tests/openvino/test_modeling.py
@@ -778,6 +778,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
del model_with_cache
gc.collect()

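# Compares stateful and stateless OpenVINO models against the transformers baseline for the four
# beam-style generation modes (beam search, beam sample, group beam search, constrained beam search).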
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@pytest.mark.run_slow
@slow
def test_beam_search(self, model_arch):
model_kwargs = {}
model_id = MODEL_NAMES[model_arch]
if model_arch in self.REMOTE_CODE_MODELS:
model_kwargs = {
"config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
"trust_remote_code": True,
}
# Qwen tokenizer does not support padding; the chatglm testing model produces nan values that are incompatible with beam search
if model_arch in ["qwen", "chatglm"]:
return

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
beam_search_gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
num_beams=4,
do_sample=False,
eos_token_id=None,
)
beam_sample_gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
num_beams=4,
do_sample=True,
eos_token_id=None,
top_k=1,
)

group_beam_search_gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
num_beams=4,
do_sample=False,
eos_token_id=None,
num_beam_groups=2,
diversity_penalty=0.0000001,
)
force_word = "cat"
force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
constrained_beam_search_gen_config = GenerationConfig(
max_new_tokens=10,
min_new_tokens=10,
num_beams=4,
do_sample=False,
eos_token_id=None,
force_words_ids=force_words_ids,
)

gen_configs = [
beam_search_gen_config,
beam_sample_gen_config,
group_beam_search_gen_config,
constrained_beam_search_gen_config,
]
ov_model_stateful = OVModelForCausalLM.from_pretrained(
model_id, export=True, use_cache=True, stateful=True, **model_kwargs
)
ov_model_stateless = OVModelForCausalLM.from_pretrained(
model_id, export=True, use_cache=True, stateful=False, **model_kwargs
)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
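# Disable EOS everywhere so every model generates exactly min/max_new_tokens and the outputs stay directly comparable.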
ov_model_stateful.generation_config.eos_token_id = None
ov_model_stateless.generation_config.eos_token_id = None
transformers_model.generation_config.eos_token_id = None
ov_model_stateful.config.eos_token_id = None
ov_model_stateless.config.eos_token_id = None
transformers_model.config.eos_token_id = None

for gen_config in gen_configs:
transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))


class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (