
Commit 5b93036

set block size as an env parameter
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent c6d2d0f commit 5b93036

File tree

2 files changed: +8 -5 lines changed


optimum/exporters/ipex/cache_utils.py

+3 -2
@@ -1,12 +1,13 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
 
 
-# May need to tune based on sequence length and different models but default to 16 currently.
-BLOCK_SIZE = 16
+# Recommend 16 on CPU and 64 on XPU.
+BLOCK_SIZE = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", 16))
 
 
 class IPEXPagedCache(Cache):
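
As a quick usage sketch (not part of the commit): because BLOCK_SIZE is resolved from the environment at import time, the override has to be set before the module shown above is imported. The value 64 below is the XPU recommendation from the new comment; the import path simply follows the file in this diff.

import os

# Must be set before optimum.exporters.ipex.cache_utils is imported,
# since BLOCK_SIZE is read from the environment at import time.
# 64 is the value recommended for XPU; 16 stays the CPU default.
os.environ["OI_PAGED_ATTN_BLOCK_SIZE"] = "64"

from optimum.exporters.ipex.cache_utils import BLOCK_SIZE

print(BLOCK_SIZE)  # 64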

optimum/exporters/ipex/modeling_utils.py

+5 -3
@@ -652,17 +652,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
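
For intuition, here is a minimal standalone sketch (not the library code) of the length tensors this change feeds to PagedAttention.flash_attn_varlen_func: on prefill (past_len == 0) every prompt token is a query token, so the query boundaries coincide with the key/value boundaries, while on decode each sequence contributes exactly one query token. The input_lens values are made up for illustration.

import torch

# Toy batch of 3 sequences with these prompt lengths (illustrative values only).
input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Cumulative key/value lengths per sequence boundary: [0, 5, 8, 15].
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

for past_len in (0, 15):
    # Same expressions as in the diff above.
    query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
    query_max_len = input_lens.max() if past_len == 0 else 1
    print(past_len, query_len_tensor.tolist(), int(query_max_len))
    # past_len == 0  -> [0, 5, 8, 15], 7  (prefill: query lengths equal kv lengths)
    # past_len == 15 -> [0, 1, 2, 3], 1   (decode: one query token per sequence)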

0 commit comments
