
Commit e75b45b

Merge branch 'block_size' into qwen

2 parents c86fd1c + 31accd2

File tree

2 files changed: +8 -8 lines changed

optimum/exporters/ipex/cache_utils.py (+3 -5)
@@ -1,14 +1,11 @@
+import os
 from typing import List, Optional, Tuple

 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig


-# May need to tune based on sequence length and different models but default to 16 currently.
-BLOCK_SIZE = 16
-
-
 class IPEXPagedCache(Cache):
     """
     A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
@@ -48,7 +45,8 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = BLOCK_SIZE
+        default_block_size = 16 if device.type == "cpu" else 64
+        self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
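
Note on the cache_utils.py change: the block size now defaults to 16 on CPU and 64 on other devices, and either default can be overridden through the OI_PAGED_ATTN_BLOCK_SIZE environment variable. A minimal standalone sketch of that selection logic follows; the helper names resolve_block_size and num_blocks are illustrative, not part of the patch.

import os

def resolve_block_size(device_type: str) -> int:
    # CPU keeps the previous default of 16; other devices (e.g. XPU) default to 64.
    default_block_size = 16 if device_type == "cpu" else 64
    # OI_PAGED_ATTN_BLOCK_SIZE, when set, overrides the device-dependent default.
    return int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))

def num_blocks(max_cache_len: int, block_size: int, batch_size: int) -> int:
    # Ceil-divide the cache length into blocks, then allocate one set of blocks per sequence,
    # mirroring the expression used in IPEXPagedCache.__init__.
    return (max_cache_len // block_size + (max_cache_len % block_size != 0)) * batch_size

os.environ["OI_PAGED_ATTN_BLOCK_SIZE"] = "32"
bs = resolve_block_size("xpu")                                   # 32: env var wins over the 64 default
print(bs, num_blocks(max_cache_len=1000, block_size=bs, batch_size=2))  # 32 64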

optimum/exporters/ipex/modeling_utils.py (+5 -3)
@@ -773,17 +773,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
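
Note on the modeling_utils.py change: flash attention is no longer restricted to the prefill step (past_len == 0). On decode steps each sequence contributes a single new query token, so the cumulative query lengths collapse to 0, 1, 2, ... and the maximum query length becomes 1, while the key/value lengths still follow input_lens. A rough standalone illustration of that bookkeeping; the wrapper function varlen_lengths is ours, not part of the patch.

import torch

def varlen_lengths(input_lens: torch.Tensor, past_len: int):
    # Cumulative key/value lengths per sequence: [0, l0, l0+l1, ...]
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    if past_len == 0:
        # Prefill: every query token is new, so query lengths mirror the key lengths.
        query_len_tensor = seq_len_tensor
        query_max_len = input_lens.max()
    else:
        # Decode: one new query token per sequence, so cumulative lengths are 0, 1, 2, ...
        query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()
        query_max_len = 1
    return seq_len_tensor, query_len_tensor, query_max_len

input_lens = torch.tensor([5, 3], dtype=torch.int32)
seq, q, qmax = varlen_lengths(input_lens, past_len=0)  # prefill: q == seq == [0, 5, 8], qmax == 5
seq, q, qmax = varlen_lengths(input_lens, past_len=4)  # decode: q == [0, 1, 2], qmax == 1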
