
Commit b69c3cd
Committed Mar 29, 2024

rework prepare_inputs_for_generation for OVModelForCausalLM

1 parent: 6d0e334

1 file changed: +3 −5 lines changed
 

‎optimum/intel/openvino/modeling_decoder.py

+3 −5
@@ -357,7 +357,6 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
@@ -509,15 +508,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif self.past_len < input_ids.shape[1]:
-                input_ids = input_ids[:, self.past_len:]
+                input_ids = input_ids[:, self.past_len :]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1]:]
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
@@ -651,8 +650,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return super().prepare_inputs_for_generation(self, input_ids, past_key_values=past_key_values, **kwargs)
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(