@@ -357,7 +357,6 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
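Editor's note (illustration, not part of the diff): the bloom branch scales batch_size because bloom's legacy cache fuses the batch and head dimensions into one leading axis. A minimal sketch, with illustrative shapes:

import torch

batch_size, num_heads, head_dim, past_len = 2, 8, 64, 5

# standard cache layout: (batch, num_heads, seq_len, head_dim)
standard_key = torch.zeros(batch_size, num_heads, past_len, head_dim)

# bloom's layout fuses batch and heads: keys are (batch * num_heads, head_dim, seq_len)
bloom_key = standard_key.permute(0, 1, 3, 2).reshape(batch_size * num_heads, head_dim, past_len)

assert bloom_key.shape[0] == batch_size * num_heads  # hence batch_size *= num_attention_heads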
@@ -509,15 +508,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
         # input_ids based on the past_length.
         elif self.past_len < input_ids.shape[1]:
-            input_ids = input_ids[:, self.past_len:]
+            input_ids = input_ids[:, self.past_len :]
         # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1]:]
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
@@ -651,8 +650,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
         if past_key_values[0][0].shape[0] == input_ids.shape[0]:
             past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return super().prepare_inputs_for_generation(self, input_ids, past_key_values=past_key_values, **kwargs)
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(
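Editor's note (minimal repro with hypothetical classes, not part of the diff): the super() call fixed above passed self explicitly, but super() already binds self, so every positional argument shifted by one.

class Base:
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        return input_ids

class Child(Base):
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # buggy: Base receives input_ids=self and past_key_values=input_ids
        return super().prepare_inputs_for_generation(self, input_ids)

print(Child().prepare_inputs_for_generation([1, 2, 3]))  # prints the Child instance, not [1, 2, 3]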