Commit f70a926

refactoring
1 parent b69c3cd commit f70a926

File tree

1 file changed (+29 -19)

optimum/intel/openvino/modeling_decoder.py (+29 -19)
@@ -120,7 +120,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
-        self.past_len = 0
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -365,12 +365,6 @@ def prepare_inputs(
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    seq_len_dim = -2
-                    if self.config.model_type == "chatglm":
-                        seq_len_dim = 0
-                    elif self.config.model_type == "qwen":
-                        seq_len_dim = 1
-                    self.past_len = past_key_values[0][1].shape[seq_len_dim]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -383,15 +377,13 @@ def prepare_inputs(
                         past_key_values = tuple(
                             past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                         )
-                else:
-                    self.past_len = past_key_values[0].shape[-2]

                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))

             # Create empty past_key_values for decoder_with_past first generation step
             elif self.use_cache:
-                self.past_len = 0
+                past_len = 0
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
                     shape = model_inputs.get_partial_shape()
@@ -414,7 +406,8 @@ def prepare_inputs(
                     # Set initial value for the next beam_idx input that will be used at the current iteration
                     # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
-                    self.past_len = 0
+                    self._past_length = 0
+        past_len = self._get_past_length(past_key_values)

         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -423,7 +416,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + self.past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                 )

         if "attention_mask" in self.input_names:
@@ -436,7 +429,7 @@ def prepare_inputs(
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
                 if past_key_values:
-                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                    position_ids = np.expand_dims(position_ids[:, -input_ids.shape[1] :], axis=-1)

             inputs["position_ids"] = position_ids

@@ -474,7 +467,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
-            self.past_len += input_ids.shape[1]
+            self._past_length += input_ids.shape[1]

         if not self.stateful:
             if self.use_cache:
@@ -485,10 +478,8 @@ def forward(
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                     )
-                self.past_len += input_ids.shape[1]
             else:
                 past_key_values = None
-                self.past_len = 0

         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

@@ -499,16 +490,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         use_cache = kwargs.get("use_cache", None)

         if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - self.past_len) :]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif self.past_len < input_ids.shape[1]:
-                input_ids = input_ids[:, self.past_len :]
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
@@ -526,6 +518,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }

+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
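
In short, the commit replaces the scattered self.past_len bookkeeping with a single _get_past_length helper plus a self._past_length counter for the stateful case. Below is a minimal, self-contained sketch of that dispatch logic, not taken from the commit: it uses plain numpy arrays in place of OpenVINO key/value tensors, a standalone function instead of the method, and treats gpt_bigcode as the multi-query-attention case purely for illustration.

import numpy as np

def get_past_length(past_key_values, model_type, stateful=False, tracked_length=0):
    """Return how many tokens are already stored in the key/value cache."""
    if past_key_values is None:
        return 0
    if stateful:
        # Stateful models keep the cache inside the inference request, so the
        # length is tracked externally (self._past_length in the commit).
        return tracked_length
    if model_type == "gpt_bigcode":  # stand-in for MULTI_QUERY_ATTN_MODELS
        return past_key_values[0].shape[-2]
    seq_length_dim = -2
    if model_type == "chatglm":
        seq_length_dim = 0
    elif model_type == "qwen":
        seq_length_dim = 1
    if isinstance(past_key_values[0], (tuple, list)):
        # cache is a tuple of (key, value) pairs per layer
        return past_key_values[0][1].shape[seq_length_dim]
    # cache was already flattened into a single tuple of tensors
    return past_key_values[1].shape[seq_length_dim]

# A llama-style cache holding 7 tokens: the sequence length lives on dim -2.
pkv = ((np.zeros((1, 8, 7, 64)), np.zeros((1, 8, 7, 64))),)
assert get_past_length(pkv, model_type="llama") == 7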
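
The prepare_inputs_for_generation changes reuse that length to keep only the tokens that have not been processed yet. A tiny worked example of the three trimming branches, again only a sketch with made-up names and dummy shapes:

import numpy as np

def trim_input_ids(input_ids, attention_mask, past_len):
    # 1 - part of the prompt lives only in the cache (e.g. inputs_embeds were used)
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_len):]
    # 2 - input_ids holds the whole sequence; drop the tokens already cached
    if past_len < input_ids.shape[1]:
        return input_ids[:, past_len:]
    # 3 - everything in input_ids is still unprocessed
    return input_ids

input_ids = np.arange(10).reshape(1, 10)      # generate() resubmits all 10 tokens
attention_mask = np.ones((1, 10), dtype=int)
print(trim_input_ids(input_ids, attention_mask, past_len=7))  # [[7 8 9]]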
