
Commit 431db93

fix

1 parent 6cd92ac

File tree

1 file changed: +8 −0 lines changed

optimum/exporters/openvino/model_patcher.py (+8 −0)
@@ -294,6 +294,14 @@ def __exit__(self, exc_type, exc_value, traceback):
 def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

+    if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+        # in order to dispatch on Flash Attention 2.
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+        ):
+            return None
+
     dtype, device = input_tensor.dtype, input_tensor.device

     # using minimum from dtype with larger bandwidth (float32) may lead to overflow
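
Why the early return None matters (a minimal sketch, not part of the commit): torch.nn.functional.scaled_dot_product_attention can only dispatch to its fused Flash Attention kernel when no explicit attn_mask tensor is passed; for purely causal masking, the is_causal=True flag is numerically equivalent and keeps that fast path available. The sketch below assumes PyTorch >= 2.0, and the tensor shapes are arbitrary example values.

    # Sketch only: illustrates why returning None instead of a mask tensor
    # lets SDPA pick its fused (Flash Attention) kernel.
    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # Passing an explicit additive causal mask forces SDPA onto its generic
    # masked-attention path.
    seq_len = q.shape[-2]
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    # With attn_mask=None and is_causal=True, PyTorch is free to select the
    # fused kernel; the result is the same causal attention.
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    assert torch.allclose(out_masked, out_causal, atol=1e-5)

The patched helper applies the same idea at export time: when AttentionMaskConverter._ignore_causal_mask_sdpa decides the mask is redundant, the function returns None rather than materializing a full causal mask tensor.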
