Commit a48e0ca

Rework inputs preparation for OVModelForCausalLM (#620)
* refactor OVModelForCausalLM class
* rework prepare_inputs_for_generation for OVModelForCausalLM
* refactoring
* Apply suggestions from code review
* fix position ids and add tests
1 parent 447ef50 commit a48e0ca
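
For context, the sketch below exercises the reworked input preparation through the public generate() API and mirrors the reproducibility check added to the test suite in this commit. It is not part of the diff: the checkpoint id, prompts, and generation settings are illustrative assumptions.

import torch
from transformers import AutoTokenizer, GenerationConfig
from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # example checkpoint (assumption, not from this commit)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Batched prompts of different lengths: padding is where the position_ids fix matters.
tokens = tokenizer(["This is", "a sample", "input for generation"], return_tensors="pt", padding=True)
generation_config = GenerationConfig(max_new_tokens=10, do_sample=False)

outputs = model.generate(**tokens, generation_config=generation_config)
# A second call with the same inputs should reproduce the first one,
# which is exactly what the new test asserts.
outputs2 = model.generate(**tokens, generation_config=generation_config)
assert torch.allclose(outputs, outputs2)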

File tree

2 files changed: +45 -56 lines changed

optimum/intel/openvino/modeling_decoder.py (+40 -56)
@@ -120,6 +120,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -356,19 +357,14 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads

         inputs = {}
-        past_len = 0
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    past_len = past_key_values[0][1].shape[-2]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -381,8 +377,6 @@ def prepare_inputs(
                         past_key_values = tuple(
                             past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                         )
-                else:
-                    past_len = past_key_values[0].shape[-2]

                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))
@@ -411,6 +405,8 @@ def prepare_inputs(
                     # Set initial value for the next beam_idx input that will be used at the current iteration
                     # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
+                    self._past_length = 0
+        past_len = self._get_past_length(past_key_values)

         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -432,7 +428,7 @@ def prepare_inputs(
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
                 if past_key_values:
-                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                    position_ids = position_ids[:, -input_ids.shape[1] :]

             inputs["position_ids"] = position_ids

@@ -470,6 +466,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
+            self._past_length += input_ids.shape[1]

         if not self.stateful:
             if self.use_cache:
@@ -485,19 +482,32 @@ def forward(

         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)

+        if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
+        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
@@ -507,6 +517,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             "attention_mask": attention_mask,
         }

+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
@@ -573,10 +601,6 @@ def _from_pretrained(
         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
             init_cls = OVBloomForCausalLM
-        elif model_type == "mpt":
-            init_cls = OVMPTForCausalLM
-        elif model_type == "opt":
-            init_cls = OVOPTForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
@@ -630,22 +654,12 @@ def _from_pretrained(
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
         # only last token for input_ids if past is not None
         if past_key_values and not self.stateful:
             # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(
@@ -712,36 +726,6 @@ def _convert_to_standard_cache(
         )


-class OVOPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class OVMPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
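
Read together, the hunks above implement a single rule in prepare_inputs_for_generation: determine how many tokens the key/value cache already holds (via the new _get_past_length), feed only the unprocessed tail of input_ids to the model, and slice position_ids to match. The standalone sketch below restates that rule with NumPy only; trim_inputs is a hypothetical helper used for illustration, not a function from the repository.

import numpy as np

def trim_inputs(input_ids: np.ndarray, attention_mask: np.ndarray, past_len: int):
    # Case 1: attention_mask is longer than input_ids (some inputs were passed only
    # through the cache); keep the last (mask length - past length) tokens.
    if attention_mask.shape[1] > input_ids.shape[1]:
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_len):]
    # Case 2: input_ids still holds the full history; drop the already-cached prefix.
    elif past_len < input_ids.shape[1]:
        input_ids = input_ids[:, past_len:]
    # Case 3: otherwise input_ids already contains only unprocessed tokens.
    position_ids = np.cumsum(attention_mask, axis=1) - 1
    position_ids[attention_mask == 0] = 1
    position_ids = position_ids[:, -input_ids.shape[1]:]  # keep positions for the kept tokens
    return input_ids, position_ids

# One decoding step after a 4-token prompt: only the newest token is fed to the model.
ids = np.array([[11, 12, 13, 14, 15]])
mask = np.ones((1, 5), dtype=np.int64)
new_ids, new_pos = trim_inputs(ids, mask, past_len=4)
assert new_ids.tolist() == [[15]] and new_pos.tolist() == [[4]]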

tests/openvino/test_modeling.py (+5)
@@ -632,6 +632,11 @@ def test_multiple_inputs(self, model_arch):
         outputs = model.generate(**tokens, generation_config=generation_config)
         self.assertIsInstance(outputs, torch.Tensor)
         self.assertEqual(outputs.shape[0], 3)
+        # test that generation result is reproducible
+        outputs2 = model.generate(**tokens, generation_config=generation_config)
+        self.assertIsInstance(outputs2, torch.Tensor)
+        self.assertEqual(outputs2.shape[0], 3)
+        self.assertTrue(torch.allclose(outputs2, outputs))
         del model
         gc.collect()