
Commit 3841443

rework prepare_inputs_for_generation for OVModelForCausalLM

1 parent: d375995

1 file changed (+3, -4)


optimum/intel/openvino/modeling_decoder.py (+3, -4)

```diff
@@ -378,7 +378,6 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
```
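
In the untouched context here, bloom scales the effective batch size by the number of attention heads. A plausible reading (not stated in the diff itself) is that bloom-format KV caches fuse the batch and head axes into one leading dimension; the sketch below illustrates the two layouts with made-up shape values:

```python
import torch

# Illustrative sketch only: shows why a bloom-format cache makes the model's
# effective batch dimension batch_size * num_heads. All names and sizes here
# are invented for the example; none come from modeling_decoder.py.
batch_size, num_heads, head_dim, past_len = 2, 16, 64, 5

# Standard cache layout: one tensor per layer with an explicit head axis.
standard_key = torch.zeros(batch_size, num_heads, past_len, head_dim)

# Bloom-style layout: batch and head axes fused into the leading dimension.
bloom_key = torch.zeros(batch_size * num_heads, head_dim, past_len)

assert standard_key.shape[0] == batch_size
assert bloom_key.shape[0] == batch_size * num_heads
```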

```diff
@@ -530,15 +529,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
         # input_ids based on the past_length.
         elif self.past_len < input_ids.shape[1]:
-            input_ids = input_ids[:, self.past_len:]
+            input_ids = input_ids[:, self.past_len :]
         # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1]:]
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
```

```diff
@@ -672,7 +671,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs
         # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
         if past_key_values[0][0].shape[0] == input_ids.shape[0]:
             past_key_values = self._convert_to_bloom_cache(past_key_values)
-
+
         return super().prepare_inputs_for_generation(self, input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
```
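
Aside from the whitespace-only change, the context shows how the bloom subclass detects a standard-format cache: a leading dimension equal to the batch size. A hedged sketch of that shape test, with invented dimensions (only the `shape[0] == input_ids.shape[0]` comparison comes from the diff):

```python
import torch

# Illustrative shapes only; sizes and layouts are assumptions for the sketch.
batch_size, num_heads, head_dim, past_len = 2, 16, 64, 4
input_ids = torch.ones(batch_size, 1, dtype=torch.long)

# Standard layout: key/value tensors keep batch as the leading dimension.
standard_cache = ((torch.zeros(batch_size, num_heads, past_len, head_dim),) * 2,)

# Bloom layout: batch and heads fused, so the leading dimension differs.
bloom_cache = (
    (
        torch.zeros(batch_size * num_heads, head_dim, past_len),  # key
        torch.zeros(batch_size * num_heads, past_len, head_dim),  # value
    ),
)

# Same comparison as in the diff above.
assert standard_cache[0][0].shape[0] == input_ids.shape[0]   # needs conversion
assert bloom_cache[0][0].shape[0] != input_ids.shape[0]      # already bloom format
```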
