optimize first latency beam search for OVModelForCausalLM #695

Merged

Changes from 6 commits
187 changes: 172 additions & 15 deletions optimum/intel/openvino/modeling_decoder.py
@@ -17,7 +17,7 @@
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino
@@ -28,6 +28,10 @@
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.utils.normalized_config import NormalizedConfigManager
@@ -41,6 +45,11 @@
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE


if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.streamers import BaseStreamer


logger = logging.getLogger(__name__)

core = Core()
@@ -122,6 +131,8 @@ def __init__(
self._pkv_precision = Type.f32
self.next_beam_idx = None
self._past_length = 0
self._first_iter_beam_search = False
self._second_iter_beam_search = False
self.update_pkv_precision()
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
@@ -369,13 +380,18 @@ def prepare_inputs(
**kwargs,
) -> Dict:
batch_size = input_ids.shape[0]
duplication_indices = None
if self.config.model_type == "bloom":
batch_size *= self.config.num_attention_heads

inputs = {}
if not self.stateful:
if past_key_values is not None:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
):
if self._pkv_precision == Type.bf16:
# numpy does not support bf16; pretend it is f16 for now, should be changed to bf16
past_key_values = tuple(
@@ -418,7 +434,6 @@ def prepare_inputs(
self.next_beam_idx = np.arange(batch_size, dtype=int)
self._past_length = 0
past_len = self._get_past_length(past_key_values)

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names or "position_ids" in self.input_names:
@@ -448,7 +463,9 @@ def prepare_inputs(
self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
)

return inputs
if self._first_iter_beam_search:
inputs, duplication_indices = self._deduplicate_inputs(inputs)
return inputs, duplication_indices

def forward(
self,
@@ -460,14 +477,13 @@ def forward(
) -> CausalLMOutputWithPast:
self.compile()

inputs = self.prepare_inputs(
inputs, duplication_idicies = self.prepare_inputs(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
**kwargs,
)

# Run inference
self.request.start_async(inputs, share_inputs=True)
self.request.wait()
@@ -483,14 +499,22 @@ def forward(
if self.use_cache:
# Tuple of length equal to: number of layers * number of past_key_values per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
):
# Tuple of tuples of length `n_layers`, with each inner tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
else:
past_key_values = None

if self._first_iter_beam_search:
logits, past_key_values = self._expand_outputs_for_generation(duplication_idicies, logits, past_key_values)
self._first_iter_beam_search = False

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

# Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
@@ -520,20 +544,124 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

return {
model_inputs = {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": position_ids,
"attention_mask": attention_mask,
}

return model_inputs

def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
if (
self.config.model_type not in MULTI_QUERY_ATTN_MODELS
or self.config.model_type == "falcon"
and self.config.new_decoder_architecture
):
past_key_values = tuple(
tuple(
past_state[indicies]
if not self.config.model_type == "chatglm"
else past_state[:, indicies, ...]
for past_state in layer_past
)
for layer_past in past_key_values
)
else:
past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
if self.stateful:
self.next_beam_idx = (
self.next_beam_idx[indicies]
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
return logits, past_key_values
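
Note: the expansion above relies on NumPy fancy indexing over the batch axis of each cached state; most models keep the batch on axis 0, while chatglm keeps it on axis 1. A minimal sketch with toy shapes (illustrative only, not the exact cache layout of any specific checkpoint):

import numpy as np

# Duplication indices map each of the 4 beams back to one of the 2 unique rows.
reverse = np.array([0, 0, 1, 1])

past_axis0 = np.zeros((2, 8, 5, 64))          # (batch, heads, seq, head_dim)
past_axis1 = np.zeros((5, 2, 8, 64))          # (seq, batch, heads, head_dim), chatglm-style

expanded_axis0 = past_axis0[reverse]          # batch axis grows 2 -> 4
expanded_axis1 = past_axis1[:, reverse, ...]  # same expansion, indexing axis 1

assert expanded_axis0.shape[0] == 4
assert expanded_axis1.shape[1] == 4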

def _deduplicate_inputs(self, model_inputs: Dict):
input_ids = model_inputs["input_ids"]
upd_model_inputs = {}
unique_input_ids, indicies, reverse_indicies = np.unique(
input_ids, axis=0, return_index=True, return_inverse=True
)
for input_name, input_tensor in model_inputs.items():
if input_name not in ["input_ids", "beam_idx"]:
if not isinstance(input_tensor, Tensor):
upd_model_inputs[input_name] = input_tensor[indicies]
else:
shape = input_tensor.shape
dtype = input_tensor.element_type
upd_batch_size = indicies.shape[0]
if self.config.model_type == "bloom":
upd_batch_size *= self.config.num_attention_heads
shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
upd_model_inputs[input_name] = Tensor(dtype, shape)
upd_model_inputs["input_ids"] = unique_input_ids
if "beam_idx" in model_inputs:
beam_range = (
unique_input_ids.shape[0]
if self.config.model_type != "bloom"
else unique_input_ids.shape[0] * self.config.num_attention_heads
)
beam_idx = np.arange(beam_range, dtype=int)
upd_model_inputs["beam_idx"] = beam_idx
return upd_model_inputs, reverse_indicies
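
For reference, a small self-contained example of the `np.unique` call used above: on the first beam-search step every beam carries the same prompt, so `indices` selects the single row that is actually sent to inference, while the inverse mapping is kept to re-expand the outputs afterwards (toy values, NumPy only):

import numpy as np

num_beams = 4
input_ids = np.tile(np.array([[3, 5, 7]]), (num_beams, 1))   # 4 identical prompt copies
attention_mask = np.ones_like(input_ids)

unique_ids, indices, reverse = np.unique(
    input_ids, axis=0, return_index=True, return_inverse=True
)
reverse = np.asarray(reverse).reshape(-1)    # keep it 1-D across NumPy versions

deduped_mask = attention_mask[indices]       # shape (1, 3): only the unique row runs
assert unique_ids.shape[0] == 1
assert np.array_equal(unique_ids[reverse], input_ids)   # reverse restores all 4 beams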

@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
_generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
generation_mode = _generation_config.get_generation_mode(assistant_model)

is_beam_search = generation_mode in [
GenerationMode.BEAM_SEARCH,
GenerationMode.BEAM_SAMPLE,
GenerationMode.GROUP_BEAM_SEARCH,
GenerationMode.CONSTRAINED_BEAM_SEARCH,
]
if is_beam_search:
self._first_iter_beam_search = True
result = super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
assistant_model,
streamer,
negative_prompt_ids,
negative_prompt_attention_mask,
**kwargs,
)
return result
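
A minimal usage sketch showing when this override takes effect; the checkpoint id and prompt are placeholders for illustration, and `export=True` assumes the model is converted to OpenVINO on the fly:

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"   # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("The weather today is", return_tensors="pt")
# num_beams > 1 resolves to GenerationMode.BEAM_SEARCH, so generate() sets
# _first_iter_beam_search and the first forward pass runs only on the
# deduplicated prompt instead of num_beams identical copies.
outputs = model.generate(**inputs, num_beams=4, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))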

def _get_past_length(self, past_key_values=None):
if past_key_values is None:
return 0
if self.stateful:
return self._past_length
if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return past_key_values[0].shape[-2]
seq_length_dim = -2
if self.config.model_type == "chatglm":
@@ -558,12 +686,20 @@ def _reorder_cache(
if self.stateful:
# TODO: Apply it differently based on model type
# TODO: At least for bloom we need to replicate values for each attention head
self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration
self.next_beam_idx = (
np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
) # save beam_idx to be used as an input in the next iteration
self._second_iter_beam_search = False
return past_key_values
else:
return tuple(
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return tuple(
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
for layer_past in past_key_values
)
return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)
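
The non-stateful branches above reorder each cached state with `np.take` along the batch axis; a toy example of what that gather does (shapes are illustrative):

import numpy as np

past_state = np.arange(4 * 2 * 3).reshape(4, 2, 3)   # (beams, seq, head_dim), toy shape
beam_idx = np.array([2, 2, 0, 1])                    # beam i copies its cache from beam_idx[i]

reordered = np.take(past_state, beam_idx, axis=0)
assert reordered.shape == past_state.shape
assert np.array_equal(reordered[0], past_state[2])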

def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -684,11 +820,12 @@ def _reorder_cache(
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""
if self.stateful:
beam_idx = np.array(beam_idx)
batch_size = beam_idx.shape[0]
beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
indices = np.array(range(batch_size * self.config.num_attention_heads))
indices = indices.reshape([batch_size, self.config.num_attention_heads])
self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
self._second_iter_beam_search = False
return past_key_values
else:
standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
@@ -738,14 +875,34 @@ def _convert_to_standard_cache(
for layer_past in past_key_value
)

def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
past_key_values = self._convert_to_bloom_cache(pkv)

if self.stateful:
self.next_beam_idx = (
self.next_beam_idx[indicies]
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
return logits, past_key_values
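
For the bloom variant above, the cache fuses batch and attention heads in its leading dimension, which is why the states are converted to the standard layout before indexing and back afterwards. A toy reshape-based sketch of that idea (not the exact `_convert_to_standard_cache` / `_convert_to_bloom_cache` implementations):

import numpy as np

batch, heads, seq, dim = 2, 4, 5, 8
indices = np.array([0, 0, 1, 1])                     # expand 2 unique rows to 4 beams

value = np.zeros((batch * heads, seq, dim))          # bloom-style fused (batch*heads, ...) layout
standard = value.reshape(batch, heads, seq, dim)     # split batch and heads apart
expanded = standard[indices]                         # index per sample, batch 2 -> 4
fused = expanded.reshape(indices.shape[0] * heads, seq, dim)   # back to the fused layout

assert fused.shape[0] == indices.shape[0] * heads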


class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
if self.stateful:
self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration
# save beam_idx to be used as an input in the next iteration
self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
self._second_iter_beam_search = False
return past_key_values
else:
return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
2 changes: 1 addition & 1 deletion optimum/intel/openvino/quantization.py
@@ -688,7 +688,7 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConf
nsamples = quantization_config.num_samples if quantization_config.num_samples else 128
calibration_dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
calibration_dataset = prepare_dataset(calibration_dataset)
calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x))
calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x)[0])

return calibration_dataset
