@@ -291,11 +291,9 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

-    # for compatibility with https://github.com/huggingface/transformers/pull/30047
-    current_length = kwargs.get("current_length", cache_position[-1])
     dtype, device = input_tensor.dtype, input_tensor.device

     # using minimum from dtype with larger bandwith (floa32) may lead to overflow
@@ -305,7 +303,14 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+        if past_seen_tokens is not None:
+            current_length = past_seen_tokens + sequence_length + 1
+        # TODO: remove after support of transformers >= v4.40.0
+        else:
+            current_length = cache_position[-1] + 1
+
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
+

     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
     if sequence_length != 1:
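For readers following the change, here is a minimal standalone sketch of the dynamic-cache branch after this patch. The helper name is illustrative and not part of the patched file, and `sequence_length` is assumed to be the current input length computed elsewhere in the surrounding function:

```python
import torch


def _dynamic_cache_target_length(attention_mask, cache_position, sequence_length, past_seen_tokens=None):
    """Illustrative re-statement of the dynamic-cache branch above (not the patched function itself)."""
    if past_seen_tokens is not None:
        # newer transformers passes the number of already-cached tokens explicitly
        current_length = past_seen_tokens + sequence_length + 1
    else:
        # older transformers: fall back to the last cache position
        current_length = cache_position[-1] + 1

    # a 2D attention mask already encodes the full (past + current) length
    return attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length


# hypothetical values: 5 cached tokens, a 3-token input, no explicit attention mask
print(_dynamic_cache_target_length(None, torch.arange(5, 8), sequence_length=3, past_seen_tokens=5))  # 9
print(_dynamic_cache_target_length(None, torch.arange(5, 8), sequence_length=3))  # tensor(8)
```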