
Fix llama and gemma modeling patching for openvino export #714

Merged · 15 commits · May 23, 2024
104 changes: 103 additions & 1 deletion optimum/exporters/openvino/model_patcher.py
@@ -293,7 +293,7 @@ def __exit__(self, exc_type, exc_value, traceback):
# adopted from
# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
@@ -306,10 +306,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

dtype, device = input_tensor.dtype, input_tensor.device

# difference with original modeling
# using minimum from dtype with larger bandwidth (float32) may lead to overflow
# during execution on platforms with default lower precision (bfloat16, float16)
min_dtype = torch.finfo(torch.float16).min
sequence_length = input_tensor.shape[1]
# difference with original modeling
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
@@ -321,7 +323,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length

# difference with original modeling
causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype

if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
@@ -358,6 +362,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
return causal_mask
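
The "difference with original modeling" comments above pin min_dtype to torch.finfo(torch.float16).min instead of the minimum of the runtime dtype. A minimal standalone illustration of why (assuming the exported model may later run in float16; not part of the patch):

import torch

# float32's minimum is not representable in float16 and overflows to -inf once cast,
# which can propagate NaNs after the mask is added to attention scores;
# float16's minimum stays finite, so it is a safe "large negative" fill value.
fp32_min = torch.finfo(torch.float32).min   # ~ -3.4e38
fp16_min = torch.finfo(torch.float16).min   # -65504.0

print(torch.tensor(fp32_min).to(torch.float16))  # tensor(-inf)
print(torch.tensor(fp16_min).to(torch.float16))  # tensor(-65504.)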


# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
def _llama_gemma_update_causal_mask_latest(
self,
attention_mask,
input_tensor,
cache_position,
past_key_values,
output_attentions,
):
from transformers.cache_utils import StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
# difference with original modeling
# using minimum from dtype with larger bandwidth (float32) may lead to overflow
# during execution on platforms with default lower precision (bfloat16, float16)
min_dtype = torch.finfo(torch.float16).min

sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
# difference with original modeling
causal_mask = (
torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
)

if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
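
For intuition, a small self-contained walkthrough of the mask construction shared by both variants, using illustrative values (not part of the patch):

import torch

# Two new tokens are decoded at absolute positions 3 and 4, over 5 key/value slots in total.
min_dtype = torch.finfo(torch.float16).min
sequence_length, target_length = 2, 5
cache_position = torch.tensor([3, 4])

causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float32) * min_dtype
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
print(causal_mask)
# Row 0 (position 3) keeps min_dtype only in column 4, its one future position;
# row 1 (position 4) is all zeros, so it attends to every position.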


# TODO : deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0
if is_transformers_version(">", "4.40.2"):
_llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
else:
_llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy


class GemmaModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
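
The GemmaModelPatcher.__enter__ body is collapsed above. As a hypothetical sketch only (attribute names and layout are assumptions, not the actual optimum-intel implementation), a patcher of this kind typically rebinds the selected mask function on enter and restores the original on exit:

import types

class _CausalMaskPatcherSketch:
    # Hypothetical: swap the base model's _update_causal_mask for the patched one.
    def __init__(self, model):
        self._model = model
        self._original = None

    def __enter__(self):
        inner = self._model.model  # assumed layout: *ForCausalLM wraps the base transformer
        self._original = inner._update_causal_mask
        inner._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner)
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        self._model.model._update_causal_mask = self._original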
6 changes: 3 additions & 3 deletions optimum/intel/openvino/trainer.py
@@ -906,17 +906,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_path = os.path.join(output_dir, OV_XML_FILE_NAME)
self.compression_controller.prepare_for_export()
model_type = self.model.config.model_type.replace("_", "-")
onnx_config_class = TasksManager.get_exporter_config_constructor(
exporter_config_class = TasksManager.get_exporter_config_constructor(
exporter="onnx",
model=self.model,
task=self.task,
model_type=model_type,
)

if self.task == "text-generation":
onnx_config = onnx_config_class(self.model.config, use_past=self.model.config.use_cache)
onnx_config = exporter_config_class(self.model.config, use_past=self.model.config.use_cache)
else:
onnx_config = onnx_config_class(self.model.config)
onnx_config = exporter_config_class(self.model.config)

num_parameters = self.model.num_parameters()
save_as_external_data = use_external_data_format(num_parameters) or self.ov_config.save_onnx_model
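
For context on the last line: ONNX models are serialized as a single protobuf unless weights are stored as external data, and protobuf files are capped at 2 GB. A back-of-envelope check under that assumption (needs_external_data is a hypothetical helper, not the real use_external_data_format implementation):

def needs_external_data(num_parameters: int, bytes_per_param: int = 4) -> bool:
    # Hypothetical heuristic: float32 weights beyond the 2 GB protobuf limit
    # cannot be embedded directly in one ONNX file.
    return num_parameters * bytes_per_param >= 2 * 1024**3

print(needs_external_data(124_000_000))    # False: a GPT-2-sized model fits in one file
print(needs_external_data(7_000_000_000))  # True: a 7B-parameter model needs external data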