@@ -378,7 +378,6 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
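Why the batch dimension is scaled here: bloom keeps its key/value cache with the batch and head dimensions fused into one leading dimension, so the effective batch size seen by the exported model is batch_size * num_attention_heads. A minimal sketch with made-up numbers (the values below are illustrative only):

# Sketch (made-up numbers): bloom fuses batch and heads in its cache,
# so the effective batch dimension is batch_size * num_attention_heads.
batch_size = 2                # rows in input_ids
num_attention_heads = 16      # from the model config
effective_batch = batch_size * num_attention_heads
print(effective_batch)        # 32 -> leading dim of bloom's past key/value tensors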
@@ -530,15 +529,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
         # input_ids based on the past_length.
         elif self.past_len < input_ids.shape[1]:
-            input_ids = input_ids[:, self.past_len:]
+            input_ids = input_ids[:, self.past_len :]
         # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1]:]
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
@@ -672,7 +671,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
         if past_key_values[0][0].shape[0] == input_ids.shape[0]:
             past_key_values = self._convert_to_bloom_cache(past_key_values)
-
+
         return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
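The guard above detects the standard cache layout by checking whether the leading dimension equals the batch size; in bloom's own layout that dimension is batch_size * num_heads. For reference, a standard-to-bloom conversion typically looks like the sketch below. It follows the shapes used by transformers' Bloom model and is not necessarily the exact body of self._convert_to_bloom_cache in this codebase:

# Sketch of a standard -> bloom cache conversion (shapes follow transformers' Bloom;
# the real self._convert_to_bloom_cache may differ in details).
def convert_to_bloom_cache(past_key_values):
    # standard per-layer layout: key [batch, num_heads, head_dim, seq], value [batch, num_heads, seq, head_dim]
    # bloom per-layer layout:    key [batch * num_heads, head_dim, seq], value [batch * num_heads, seq, head_dim]
    batch_size, num_heads, head_dim, seq_len = past_key_values[0][0].shape
    fused = batch_size * num_heads
    return tuple(
        (key.reshape(fused, head_dim, seq_len), value.reshape(fused, seq_len, head_dim))
        for key, value in past_key_values
    )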