Commit 8039ae1

refactoring
1 parent 9a70c66

File tree: 1 file changed, +29 -19 lines

optimum/intel/openvino/modeling_decoder.py (+29 -19)
@@ -136,7 +136,7 @@ def __init__(
         self._original_model = self.model.clone() # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
-        self.past_len = 0
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -386,12 +386,6 @@ def prepare_inputs(
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    seq_len_dim = -2
-                    if self.config.model_type == "chatglm":
-                        seq_len_dim = 0
-                    elif self.config.model_type == "qwen":
-                        seq_len_dim = 1
-                    self.past_len = past_key_values[0][1].shape[seq_len_dim]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -404,15 +398,13 @@ def prepare_inputs(
                     past_key_values = tuple(
                         past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                     )
-                else:
-                    self.past_len = past_key_values[0].shape[-2]

                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))

             # Create empty past_key_values for decoder_with_past first generation step
             elif self.use_cache:
-                self.past_len = 0
+                past_len = 0
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
                     shape = model_inputs.get_partial_shape()
@@ -435,7 +427,8 @@ def prepare_inputs(
                     # Set initial value for the next beam_idx input that will be used at the current iteration
                     # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
-                    self.past_len = 0
+                    self._past_length = 0
+        past_len = self._get_past_length(past_key_values)

         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -444,7 +437,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + self.past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                 )

         if "attention_mask" in self.input_names:
@@ -457,7 +450,7 @@ def prepare_inputs(
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
                 if past_key_values:
-                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                    position_ids = np.expand_dims(position_ids[:, -input_ids.shape[1] :], axis=-1)

             inputs["position_ids"] = position_ids

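Aside (not part of the patch): a minimal NumPy sketch of the position_ids change above. The old slice kept only the last position, which is not enough when more than one new token is passed per step (e.g. assisted decoding); the shapes and values below are made up for illustration.

import numpy as np

attention_mask = np.ones((1, 7), dtype=np.int64)  # 5 cached tokens + 2 new tokens
input_ids = np.array([[101, 102]])                # the 2 new tokens for this step

position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1

print(position_ids[:, -1])                    # [6]      -> only the last position
print(position_ids[:, -input_ids.shape[1]:])  # [[5 6]]  -> one position per new token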
@@ -495,7 +488,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
-            self.past_len += input_ids.shape[1]
+            self._past_length += input_ids.shape[1]

         if not self.stateful:
             if self.use_cache:
@@ -506,10 +499,8 @@ def forward(
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                     )
-                self.past_len += input_ids.shape[1]
             else:
                 past_key_values = None
-                self.past_len = 0

         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

@@ -520,16 +511,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         use_cache = kwargs.get("use_cache", None)

         if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - self.past_len) :]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif self.past_len < input_ids.shape[1]:
-                input_ids = input_ids[:, self.past_len :]
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
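Aside (not part of the patch): a toy walk-through of the three truncation cases above, using the same slicing with made-up numbers. Here past_len stands in for whatever _get_past_length would report.

import numpy as np

past_len = 5                                     # tokens already in the cache
input_ids = np.arange(8).reshape(1, 8)           # generate() re-passes the full sequence
attention_mask = np.ones((1, 8), dtype=np.int64)

if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
    # case 1: some inputs are only present in the cache
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_len):]
elif past_len < input_ids.shape[1]:
    # case 2: drop the prefix that was already processed
    input_ids = input_ids[:, past_len:]
# case 3: past_len >= input_ids.shape[1] -> input_ids already holds only new tokens

print(input_ids)  # [[5 6 7]] -- only the tokens the model has not seen yet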
@@ -547,6 +539,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }

+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
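Aside (not part of the patch): a rough sketch of the non-stateful cache layouts the new helper distinguishes. The exact shapes below are assumptions for illustration only; the point is which axis carries the sequence length.

import numpy as np

heads, head_dim, past_tokens = 4, 8, 5

# default layout: (batch, heads, seq, head_dim) -> sequence length on dim -2
default_pkv = ((np.zeros((1, heads, past_tokens, head_dim)),) * 2,)
print(default_pkv[0][1].shape[-2])  # 5

# chatglm-style layout (assumed): (seq, batch, heads, head_dim) -> dim 0
chatglm_pkv = ((np.zeros((past_tokens, 1, heads, head_dim)),) * 2,)
print(chatglm_pkv[0][1].shape[0])   # 5

# qwen-style layout (assumed): (batch, seq, heads, head_dim) -> dim 1
qwen_pkv = ((np.zeros((1, past_tokens, heads, head_dim)),) * 2,)
print(qwen_pkv[0][1].shape[1])      # 5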
