
Commit 48e72ef

eaidova and echarlaix authored
Update min dtype in falcon to prevent bf16 execution issue (#1093)
* Update min dtype in falcon to prevent bf16 execution issue
* Update optimum/exporters/openvino/model_patcher.py
* typo

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 5fa9602 commit 48e72ef


1 file changed: +35 −8 lines changed

optimum/exporters/openvino/model_patcher.py

+35 −8
@@ -2501,6 +2501,40 @@ def __enter__(self):
             _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/31f9a289a6207be6cae746e009d8e0db523be203/src/transformers/models/falcon/modeling_falcon.py#L1138
+def _falcon_prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    cache_position: torch.Tensor,
+    batch_size: int,
+    **kwargs,
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        # different from original: allow to provide min_dtype as parameter
+        min_dtype = torch.finfo(dtype).min if "min_dtype" not in kwargs else kwargs["min_dtype"]
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
 def _falcon_update_causal_mask(
     self,
     attention_mask: torch.Tensor,
@@ -2520,13 +2554,6 @@ def _falcon_update_causal_mask(
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
 
-    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
-        _prepare_4d_causal_attention_mask_with_cache_position = (
-            self._prepare_4d_causal_attention_mask_with_cache_position
-        )
-    else:
-        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position
-
     if self.config._attn_implementation == "flash_attention_2":
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
@@ -2568,7 +2595,7 @@ def _falcon_update_causal_mask(
         )
 
     # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+    causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask,
         sequence_length=sequence_length,
         target_length=target_length,
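
For context, a minimal sketch (not part of the commit) of why routing min_dtype through **kwargs matters: the default fill value torch.finfo(torch.float32).min (about -3.4e38) exceeds the largest finite bfloat16 magnitude and rounds to -inf when the exported graph is later executed in bf16, which is presumably the issue the commit title refers to, while a caller-supplied value such as torch.finfo(torch.float16).min stays finite. The import assumes the helper added by this commit is available from optimum.exporters.openvino.model_patcher; the shapes are arbitrary illustration values.

import torch

# Helper introduced by this commit (assumed importable once the patch is applied).
from optimum.exporters.openvino.model_patcher import (
    _falcon_prepare_4d_causal_attention_mask_with_cache_position,
)

attention_mask = torch.ones(1, 4)  # illustrative 2D padding mask, no padded tokens
cache_position = torch.arange(4)
common = dict(
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=cache_position,
    batch_size=1,
)

# Default behaviour: masked positions are filled with torch.finfo(torch.float32).min.
default_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(attention_mask, **common)

# Patched usage: the caller supplies a fill value that stays finite in half/bfloat16 precision.
safe_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask, min_dtype=torch.finfo(torch.float16).min, **common
)

print(default_mask.min())  # -3.4028e+38, rounds to -inf when cast down to bfloat16
print(safe_mask.min())     # -65504.0, representable in both float16 and bfloat16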

0 commit comments
