Commit 1319d7b

Fix llama and gemma modeling patching for openvino export (#714)
* Fix compatibility for transformers v4.41.0 llama and gemma modeling patching
* fix for dev transformers version
* update setup
1 parent c69fe32 commit 1319d7b
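
The patched causal-mask helpers run while llama/gemma checkpoints are traced for OpenVINO export. As a rough illustration of the code path this commit affects (a sketch only; the model id and output directory are placeholders, not taken from this commit):

# Sketch: exporting a decoder-only checkpoint to OpenVINO with optimum-intel.
# The export step applies the patching in optimum/exporters/openvino/model_patcher.py.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "google/gemma-2b"  # placeholder llama/gemma-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the PyTorch checkpoint to an OpenVINO IR on the fly
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
ov_model.save_pretrained("gemma-2b-openvino")
tokenizer.save_pretrained("gemma-2b-openvino")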

2 files changed: +106 −4 lines
optimum/exporters/openvino/model_patcher.py (+103 −1)
@@ -301,7 +301,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

     if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
@@ -314,10 +314,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

     dtype, device = input_tensor.dtype, input_tensor.device

+    # difference with original modeling
     # using minimum from dtype with larger bandwidth (float32) may lead to overflow
     # during execution on platforms with default lower precision (bfloat16, float16)
     min_dtype = torch.finfo(torch.float16).min
     sequence_length = input_tensor.shape[1]
+    # difference with original modeling
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
@@ -329,7 +331,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

     target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length

+    # difference with original modeling
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+
     if sequence_length != 1:
         causal_mask = torch.triu(causal_mask, diagonal=1)
     causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
@@ -366,6 +370,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     return causal_mask


+# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
+def _llama_gemma_update_causal_mask_latest(
+    self,
+    attention_mask,
+    input_tensor,
+    cache_position,
+    past_key_values,
+    output_attentions,
+):
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using minimum from dtype with larger bandwidth (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+
+    sequence_length = input_tensor.shape[1]
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        # difference with original modeling
+        causal_mask = (
+            torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+        )
+
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+# TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0
+if is_transformers_version(">", "4.40.2"):
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
+else:
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
+
+
 class GemmaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
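
The "# difference with original modeling" comments above concern the mask fill value: upstream transformers derives it from the runtime dtype via torch.finfo(dtype).min, while the patch pins min_dtype to the float16 minimum. A standalone snippet (not part of the commit) illustrating the overflow the comments warn about:

import torch

fp32_min = torch.finfo(torch.float32).min          # about -3.4028e+38
print(torch.tensor(fp32_min).to(torch.float16))    # tensor(-inf): the float32 minimum overflows in half precision
print(torch.finfo(torch.float16).min)              # -65504.0: representable on float16/bfloat16 hardware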

optimum/intel/openvino/trainer.py (+3 −3)
@@ -906,17 +906,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         output_path = os.path.join(output_dir, OV_XML_FILE_NAME)
         self.compression_controller.prepare_for_export()
         model_type = self.model.config.model_type.replace("_", "-")
-        onnx_config_class = TasksManager.get_exporter_config_constructor(
+        exporter_config_class = TasksManager.get_exporter_config_constructor(
             exporter="onnx",
             model=self.model,
             task=self.task,
             model_type=model_type,
         )

         if self.task == "text-generation":
-            onnx_config = onnx_config_class(self.model.config, use_past=self.model.config.use_cache)
+            onnx_config = exporter_config_class(self.model.config, use_past=self.model.config.use_cache)
         else:
-            onnx_config = onnx_config_class(self.model.config)
+            onnx_config = exporter_config_class(self.model.config)

         num_parameters = self.model.num_parameters()
         save_as_external_data = use_external_data_format(num_parameters) or self.ov_config.save_onnx_model
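
The trainer change is a rename only, presumably because the constructor returned by TasksManager.get_exporter_config_constructor is shared across export backends rather than being ONNX-specific. For reference, a minimal sketch of that call outside the trainer (the gpt2 checkpoint is a placeholder; the arguments mirror the diff above):

from transformers import AutoModelForCausalLM
from optimum.exporters.tasks import TasksManager

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
exporter_config_class = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    model=model,
    task="text-generation",
    model_type=model.config.model_type.replace("_", "-"),
)
# for text-generation, use_past controls exporting with a KV cache, as the trainer does
export_config = exporter_config_class(model.config, use_past=model.config.use_cache)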
