Commit 1782a50

support assisted decoding and add reorder cache function

1 parent 151712d, commit 1782a50
File tree

1 file changed: optimum/intel/ipex/modeling_base.py (+96, -3 lines)
@@ -348,11 +348,95 @@ def __init__(
         if warmup:
             self._init_warmup()
 
+    def _reorder_cache(
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called.
+        This is required to match `past_key_values` with the correct beam_idx at every generation step.
+        """
+        if self.config.model_type == "bloom":
+            return self._reorder_cache_bloom(past_key_values, beam_idx)
+
+        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past_key_values
+        )
+
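For context, a minimal sketch (not part of the diff) of what this generic reordering does on a toy two-layer cache; the shapes and beam indices below are illustrative assumptions:

    import torch

    # Toy cache: 2 layers, each a (key, value) pair with the beam dimension first.
    past_key_values = tuple(
        (torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)) for _ in range(2)
    )
    # After one beam-search step, the best continuations come from beams 2, 0, 0.
    beam_idx = torch.tensor([2, 0, 0])

    # Same index_select-based reordering as the generic branch in the diff.
    reordered = tuple(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )
    assert torch.equal(reordered[0][0][0], past_key_values[0][0][2])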
+    # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
+    def _reorder_cache_bloom(
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called for the bloom architecture.
+        This is required to match `past_key_values` with the correct beam_idx at every generation step.
+        """
+        standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device)
+            for layer_past in past_key_values
+            for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
+        )
+        return self._convert_to_bloom_cache(reordered_past)
+
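The device_to_beam_idx mapping above simply caches one copy of beam_idx per device found in the cache, so multi-device caches can be indexed without repeated transfers. A minimal sketch (single device, toy tensors, not part of the diff):

    import torch

    beam_idx = torch.tensor([1, 0, 2])
    past_key_values = tuple(
        (torch.randn(3, 16, 4, 8), torch.randn(3, 16, 4, 8)) for _ in range(2)
    )

    # One beam_idx copy per device present in the cache (all CPU in this toy case).
    device_to_beam_idx = {
        past_state.device: beam_idx.to(past_state.device)
        for layer_past in past_key_values
        for past_state in layer_past
    }
    assert list(device_to_beam_idx) == [torch.device("cpu")]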
+    # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
+    @staticmethod
+    def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache
+    def _convert_to_standard_cache(
+        self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
+        """
+        if self.config.model_type != "bloom":
+            return past_key_value
+
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
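As a sanity check (not part of the diff), here is a minimal sketch, with assumed toy shapes, of the round trip between the standard layout and Bloom's fused batch*heads layout that these two helpers perform:

    import torch

    batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
    # Standard layout: key [batch, heads, head_dim, seq], value [batch, heads, seq, head_dim].
    key = torch.randn(batch_size, num_heads, head_dim, seq_length)
    value = torch.randn(batch_size, num_heads, seq_length, head_dim)

    # Bloom layout fuses batch and heads into the first dimension via view().
    bloom_key = key.view(batch_size * num_heads, head_dim, seq_length)
    bloom_value = value.view(batch_size * num_heads, seq_length, head_dim)

    # Converting back recovers the original tensors exactly.
    assert torch.equal(bloom_key.view(batch_size, num_heads, head_dim, seq_length), key)
    assert torch.equal(bloom_value.view(batch_size, num_heads, seq_length, head_dim), value)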
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         past_key_values = past_key_values or kwargs.get("past", None)
 
         if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = self.get_past_length(past_key_values)
+            input_ids = input_ids[:, past_length:]
 
         # `past_key_values` may be in the standard format (e.g. in contrastive search), converts to bloom's format if needed
         if past_key_values is not None and self.config.model_type == "bloom":
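The switch from always keeping only the last token to slicing with the past length is what enables assisted decoding, where several candidate tokens arrive in a single step. A minimal sketch of the slicing (toy values, not from the diff):

    import torch

    # 10 tokens total; the cache already covers the first 7, and an assistant
    # model has proposed 3 new candidate tokens that must all be verified.
    input_ids = torch.arange(10).unsqueeze(0)
    past_length = 7

    # Old behaviour: input_ids[:, -1:] keeps a single token per step.
    # New behaviour: keep every token the cache has not seen yet.
    new_tokens = input_ids[:, past_length:]
    assert new_tokens.shape[-1] == 3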
@@ -368,7 +452,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[-1] :]
 
         return {
             "input_ids": input_ids,
@@ -379,6 +463,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "token_type_ids": None,
         }
 
+    def get_past_length(self, past_key_values):
+        model_type = self.config.model_type.replace("_", "-")
+        if model_type == "bloom":
+            return past_key_values[0][0].shape[-1]
+        elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[1]
+        else:
+            return past_key_values[0][0].shape[-2]
+
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
         nb_pkv = 2
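The shape conventions behind get_past_length differ per architecture. A minimal sketch of the three cases (toy shapes, assumed layouts, not part of the diff):

    import torch

    batch, heads, seq, dim = 2, 4, 5, 8

    # Bloom-style cache: key is [batch * heads, head_dim, seq], so the
    # sequence length sits in the last dimension.
    bloom_cache = ((torch.randn(batch * heads, dim, seq), torch.randn(batch * heads, seq, dim)),)
    assert bloom_cache[0][0].shape[-1] == seq

    # Multi-query-attention models keep a single fused tensor per layer with
    # the sequence length in dimension 1 (layout assumed here for illustration).
    mqa_cache = (torch.randn(batch, seq, dim),)
    assert mqa_cache[0].shape[1] == seq

    # Default layout: key/value are [batch, heads, seq, head_dim], so the
    # sequence length is the second-to-last dimension.
    default_cache = ((torch.randn(batch, heads, seq, dim), torch.randn(batch, heads, seq, dim)),)
    assert default_cache[0][0].shape[-2] == seq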
@@ -431,7 +524,7 @@ def forward(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[-1] :]
 
         if "position_ids" in self.input_names or not self.input_names:
             inputs["position_ids"] = position_ids
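Because more than one new token can now be fed per step, position_ids must be sliced to the same width as input_ids rather than to a single position. A minimal sketch of that slicing (toy mask, not from the diff):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])  # 6 tokens visible so far
    input_ids = torch.tensor([[101, 102, 103]])          # 3 new tokens this step

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    # Keep one position per incoming token instead of only the last one.
    position_ids = position_ids[:, -input_ids.shape[-1] :]
    assert position_ids.tolist() == [[3, 4, 5]]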

0 commit comments

Comments
 (0)