Skip to content

Commit 2b3f550

Browse files
authored
Fix performance issue for Qwen dynamic causal mask (#651)
* [Qwen] Fix performance issue with dynamic causal mask * [Qwen] Fix code style
1 parent e7108de commit 2b3f550

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

optimum/exporters/openvino/model_patcher.py

+12-2
Original file line number | Diff line number | Diff line change
@@ -531,9 +531,10 @@ def _qwen_attention_forward(
531531
value = value.permute(0, 2, 1, 3)
532532

533533
if not self.use_cache_quantization and SUPPORT_SDPA:
534-
causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
534+
# For performance, using constant tril to generate causal_mask
535+
causal_mask = self.bias[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
535536
if attention_mask is not None:
536-
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill(
537+
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1).masked_fill(
537538
~causal_mask, torch.finfo(query.dtype).min
538539
)
539540
else:
@@ -578,8 +579,17 @@ def __init__(
578579

579580
def __enter__(self):
580581
super().__enter__()
582+
max_positions = self._model.config.seq_length
581583
for block in self._model.transformer.h:
582584
block.attn._orig_forward = block.attn.forward
585+
# For performance, using constant tril to generate causal_mask
586+
block.attn.register_buffer(
587+
"bias",
588+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
589+
1, 1, max_positions, max_positions
590+
),
591+
persistent=False,
592+
)
583593
block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn)
584594

585595
def __exit__(self, exc_type, exc_value, traceback):

0 commit comments

Comments
 (0)