
Commit d375995

refactor OVModelForCausalLM class
1 parent d2f9fdb commit d375995

File tree

1 file changed: +31 -54 lines

optimum/intel/openvino/modeling_decoder.py (+31 -54)
@@ -136,6 +136,7 @@ def __init__(
         self._original_model = self.model.clone() # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
+        self.past_len = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -377,19 +378,21 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]

         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads

         inputs = {}
-        past_len = 0
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    past_len = past_key_values[0][1].shape[-2]
+                    seq_len_dim = -2
+                    if self.config.model_type == "chatglm":
+                        seq_len_dim = 0
+                    elif self.config.model_type == "qwen":
+                        seq_len_dim = 1
+                    self.past_len = past_key_values[0][1].shape[seq_len_dim]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -403,13 +406,14 @@ def prepare_inputs(
                             past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                         )
                 else:
-                    past_len = past_key_values[0].shape[-2]
+                    self.past_len = past_key_values[0].shape[-2]

                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))

             # Create empty past_key_values for decoder_with_past first generation step
             elif self.use_cache:
+                self.past_len = 0
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
                     shape = model_inputs.get_partial_shape()
@@ -432,6 +436,7 @@ def prepare_inputs(
             # Set initial value for the next beam_idx input that will be used at the current iteration
             # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
             self.next_beam_idx = np.arange(batch_size, dtype=int)
+            self.past_len = 0

         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -440,7 +445,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + self.past_len), dtype=inputs["input_ids"].dtype
                 )

         if "attention_mask" in self.input_names:
@@ -491,6 +496,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
+            self.past_len += input_ids.shape[1]

         if not self.stateful:
             if self.use_cache:
@@ -501,24 +507,38 @@ def forward(
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                     )
+                self.past_len += input_ids.shape[1]
             else:
                 past_key_values = None
+                self.past_len = 0

         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)

+        if past_key_values is not None:
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - self.past_len) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif self.past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, self.past_len:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
+        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1]:]

         return {
             "input_ids": input_ids,
@@ -594,10 +614,6 @@ def _from_pretrained(
         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
             init_cls = OVBloomForCausalLM
-        elif model_type == "mpt":
-            init_cls = OVMPTForCausalLM
-        elif model_type == "opt":
-            init_cls = OVOPTForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
@@ -651,22 +667,13 @@ def _from_pretrained(
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
         # only last token for input_ids if past is not None
         if past_key_values and not self.stateful:
             # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
+
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(
@@ -733,36 +740,6 @@ def _convert_to_standard_cache(
         )


-class OVOPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class OVMPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
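
A rough end-to-end sketch (not part of this commit) of how the refactored class gets exercised. It assumes the public optimum-intel API and a placeholder checkpoint id; generate() calls the reworked prepare_inputs_for_generation under the hood:

    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any decoder-only checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # export=True converts the checkpoint to OpenVINO IR; use_cache=True enables the KV cache
    model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

    inputs = tokenizer("OpenVINO makes inference", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
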
