Commit 91e635e

Fix causal mask update bf16 accuracy issue in gemma (#654)

* fix causal mask update bf16 accuracy issue in gemma
* update llama config

1 parent: 9724919

File tree: 2 files changed, +107 -2 lines changed

optimum/exporters/openvino/model_configs.py (+20 -1)
@@ -19,7 +19,7 @@
 from transformers.utils import is_tf_available

 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
+from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
@@ -34,6 +34,7 @@
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
+    LlamaModelPatcher,
     MixtralModelPatcher,
     QwenModelPatcher,
 )
@@ -274,6 +275,24 @@ def patch_model_for_export(
         return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)


+@register_in_tasks_manager(
+    "llama",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class LlamaOpenVINOConfig(LlamaOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class QwenDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
     def __init__(
         self,
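With the "llama" model type now registered for OpenVINO export, LlamaModelPatcher is applied whenever a llama checkpoint goes through the export path. A minimal usage sketch, assuming the optimum-intel Python API; the model ID and output directory are illustrative:

# Illustrative usage sketch: exporting a llama checkpoint exercises the newly
# registered LlamaOpenVINOConfig / LlamaModelPatcher under the hood.
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", export=True)
model.save_pretrained("llama_ov")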

optimum/exporters/openvino/model_patcher.py (+87 -1)
@@ -288,10 +288,74 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward


+# adapted from
+# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
+# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # for compatibility with https://github.com/huggingface/transformers/pull/30047
+    current_length = kwargs.get("current_length", cache_position[-1])
+    dtype, device = input_tensor.dtype, input_tensor.device
+
+    # using the minimum of the wider dtype (float32) may lead to overflow
+    # during execution on platforms with a lower default precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+    sequence_length = input_tensor.shape[1]
+    if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
+        target_length = self.config.max_position_embeddings
+    else:  # dynamic cache
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+
+    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        if attention_mask.dim() == 2:
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+        elif attention_mask.dim() == 4:
+            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+            # cache. In that case, the 4D attention mask attends to the newest tokens only.
+            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                offset = cache_position[0]
+            else:
+                offset = 0
+            mask_shape = attention_mask.shape
+            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
 class GemmaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()

+        # gemma has some accuracy issues with bf16 with transformers >= 4.39
+        # fill the causal mask in a slightly different way to avoid overflow on some platforms
+        if is_transformers_version(">=", "4.39.0"):
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(
+                _llama_gemma_update_causal_mask, self._model.model
+            )
+
         # init inv_freq for torchscript tracing
         # https://github.com/huggingface/transformers/blob/ed74d97871468f3a4695ede50abdc0b55717a84d/src/transformers/models/gemma/modeling_gemma.py#L108
         for layer in self._model.model.layers:
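The min_dtype choice above is the heart of the fix: a mask filled with torch.finfo(torch.float32).min in a graph traced in float32 overflows to -inf when the same constant is later evaluated in bfloat16 or float16, whereas the float16 minimum stays finite in all three dtypes. A small standalone sketch (not part of the commit) illustrating the difference:

import torch

# float32 min does not fit into the lower-precision formats: it rounds to -inf
fp32_min = torch.tensor(torch.finfo(torch.float32).min)
print(fp32_min.to(torch.bfloat16))  # -inf
print(fp32_min.to(torch.float16))   # -inf

# float16 min stays finite when cast to bfloat16 or float32, so masked positions
# remain large-negative instead of turning into -inf (and NaN after softmax)
fp16_min = torch.tensor(torch.finfo(torch.float16).min)
print(fp16_min.to(torch.bfloat16))  # finite
print(fp16_min.to(torch.float32))   # -65504.0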
@@ -301,6 +365,29 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )

+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+
+
+class LlamaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # llama has some accuracy issues with bf16 with transformers >= 4.39
+        # fill the causal mask in a slightly different way to avoid overflow on some platforms
+        if is_transformers_version(">=", "4.39.0"):
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(
+                _llama_gemma_update_causal_mask, self._model.model
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+

 SUPPORT_SDPA = is_torch_version(">", "2.1.0")

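Both patchers rely on the same save/patch/restore idiom: bind the shared replacement onto the model instance with types.MethodType in __enter__ and put the original method back in __exit__. A minimal, self-contained sketch of that idiom with hypothetical names (not the library's classes):

import types

class Widget:
    def greet(self):
        return "original"

def _patched_greet(self):
    return "patched"

class GreetPatcher:
    # context manager that swaps a bound method on one instance and restores it afterwards
    def __init__(self, obj):
        self._obj = obj

    def __enter__(self):
        self._obj._orig_greet = self._obj.greet
        self._obj.greet = types.MethodType(_patched_greet, self._obj)
        return self._obj

    def __exit__(self, exc_type, exc_value, traceback):
        if hasattr(self._obj, "_orig_greet"):
            self._obj.greet = self._obj._orig_greet

w = Widget()
with GreetPatcher(w):
    assert w.greet() == "patched"
assert w.greet() == "original"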
@@ -465,7 +552,6 @@ def _qwen_attention_forward(
            raise ValueError("Cannot output attentions while using flash-attn")
        else:
            outputs += (attn_weight,)
-
    return outputs
