@@ -288,10 +288,74 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward


+# adapted from
+# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
+# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # for compatibility with https://github.com/huggingface/transformers/pull/30047
+    current_length = kwargs.get("current_length", cache_position[-1])
+    dtype, device = input_tensor.dtype, input_tensor.device
+
+    # using the minimum of a dtype with larger bandwidth (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+    sequence_length = input_tensor.shape[1]
+    if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
+        target_length = self.config.max_position_embeddings
+    else:  # dynamic cache
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+
+    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        if attention_mask.dim() == 2:
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+        elif attention_mask.dim() == 4:
+            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+            # cache. In that case, the 4D attention mask attends to the newest tokens only.
+            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                offset = cache_position[0]
+            else:
+                offset = 0
+            mask_shape = attention_mask.shape
+            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
 class GemmaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()

+        # gemma has some accuracy issues with bf16 with transformers >= 4.39
+        # fill causal mask in a slightly different way to avoid overflow on some platforms
+        if is_transformers_version(">=", "4.39.0"):
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(
+                _llama_gemma_update_causal_mask, self._model.model
+            )
+
         # init inv_freq for torchscript tracing
         # https://github.com/huggingface/transformers/blob/ed74d97871468f3a4695ede50abdc0b55717a84d/src/transformers/models/gemma/modeling_gemma.py#L108
         for layer in self._model.model.layers:
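
As a side note on the hunk above: the mask is built by starting from an all-min_dtype matrix, zeroing the lower triangle, and then re-masking every key position past each token's cache position. A minimal standalone sketch of that trick on toy shapes (plain PyTorch, illustrative only, not part of this diff):

import torch

# toy dimensions: 4 query tokens attending into a key/value space of 6 positions
sequence_length, target_length = 4, 6
cache_position = torch.arange(2, 2 + sequence_length)  # queries sit at absolute positions 2..5
min_dtype = torch.finfo(torch.float16).min  # finite fp16 minimum, safe on bf16/fp16 platforms

causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float32) * min_dtype
causal_mask = torch.triu(causal_mask, diagonal=1)  # keep masking only strictly above the diagonal
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)  # unmask keys <= cache position

print(causal_mask)
# row i is 0.0 (attend) for key positions <= cache_position[i] and min_dtype (masked) afterwards
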
@@ -301,6 +365,29 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )

+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+
+
+class LlamaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # llama has some accuracy issues with bf16 with transformers >= 4.39
+        # fill causal mask in a slightly different way to avoid overflow on some platforms
+        if is_transformers_version(">=", "4.39.0"):
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(
+                _llama_gemma_update_causal_mask, self._model.model
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+

 SUPPORT_SDPA = is_torch_version(">", "2.1.0")

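
The patcher changes above all follow the same save/patch/restore pattern: stash the bound _update_causal_mask, rebind the module-level replacement with types.MethodType in __enter__, and put the original back in __exit__. A self-contained sketch of that pattern with a dummy object (nothing below is part of the PR; DummyModel is just a stand-in for model.model):

import types

class DummyModel:
    def _update_causal_mask(self, attention_mask):
        return "original"

def _patched_update_causal_mask(self, attention_mask):
    return "patched"

model = DummyModel()

# __enter__: keep a reference to the original bound method, then shadow it on the instance
model._orig_update_causal_mask = model._update_causal_mask
model._update_causal_mask = types.MethodType(_patched_update_causal_mask, model)
assert model._update_causal_mask(None) == "patched"

# __exit__: restore the original implementation
if hasattr(model, "_orig_update_causal_mask"):
    model._update_causal_mask = model._orig_update_causal_mask
assert model._update_causal_mask(None) == "original"
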
@@ -465,7 +552,6 @@ def _qwen_attention_forward(
             raise ValueError("Cannot output attentions while using flash-attn")
         else:
             outputs += (attn_weight,)
-
     return outputs