Commit 08e3c3d

Fix compatibility for transformers v4.41.0 llama and gemma modeling patching
1 parent: 2b902bb

2 files changed: +104 -2 lines changed

optimum/exporters/openvino/model_patcher.py (+103 -1)
@@ -293,7 +293,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

     if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
@@ -306,10 +306,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

     dtype, device = input_tensor.dtype, input_tensor.device

+    # difference with original modeling
     # using minimum from dtype with larger bandwith (floa32) may lead to overflow
     # during execution on platforms with default lower precision (bfloat16, float16)
     min_dtype = torch.finfo(torch.float16).min
     sequence_length = input_tensor.shape[1]
+    # difference with original modeling
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
@@ -321,7 +323,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

         target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length

+    # difference with original modeling
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+
     if sequence_length != 1:
         causal_mask = torch.triu(causal_mask, diagonal=1)
     causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
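
Note on the "difference with original modeling" lines above: the original modeling code derives the mask fill value from the input dtype, while this patch always fills with torch.finfo(torch.float16).min. The snippet below is illustrative only (not part of the commit) and shows why a float32 minimum is unsafe on targets that execute in lower precision:

import torch

# Illustration only: float32's minimum overflows to -inf once the mask is run
# in float16, whereas the float16 minimum stays finite in every precision.
fp32_min = torch.finfo(torch.float32).min  # ~ -3.4e38
fp16_min = torch.finfo(torch.float16).min  # -65504.0

print(torch.tensor(fp32_min).to(torch.float16))  # tensor(-inf)  -> overflow
print(torch.tensor(fp16_min).to(torch.float16))  # tensor(-65504.)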
@@ -358,6 +362,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     return causal_mask


+# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
+def _llama_gemma_update_causal_mask_latest(
+    self,
+    attention_mask,
+    input_tensor,
+    cache_position,
+    past_key_values,
+    output_attentions,
+):
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using minimum from dtype with larger bandwith (floa32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+
+    sequence_length = input_tensor.shape[1]
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        # difference with original modeling
+        causal_mask = (
+            torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+        )
+
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+# TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0
+if is_transformers_version(">=", "4.41.0"):
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
+else:
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
+
+
 class GemmaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
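
For orientation, the version-dispatched _llama_gemma_update_causal_mask above is the helper that patchers such as GemmaModelPatcher bind onto the exported model. The sketch below only illustrates the general monkey-patching pattern; the apply_causal_mask_patch name and the model.model attribute layout are assumptions for illustration, not code from this commit:

import types

def apply_causal_mask_patch(model):
    # `model` is assumed to be a *ForCausalLM-style wrapper whose decoder lives at model.model
    decoder = model.model
    original = decoder._update_causal_mask
    # bind the version-appropriate helper selected above as a method of the decoder
    decoder._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, decoder)
    return original  # kept so a patcher can restore the original method on __exit__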

setup.py (+1 -1)
@@ -28,7 +28,7 @@

 INSTALL_REQUIRE = [
     "torch>=1.11",
-    "transformers>=4.36.0,<4.41.0",
+    "transformers @ git+https://github.com/huggingface/transformers.git",
     "optimum~=1.19",
     "datasets>=1.4.0",
     "sentencepiece",
