2 files changed, +8 -8 lines changed
@@ -1,14 +1,11 @@
+import os
 from typing import List, Optional, Tuple

 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig


-# May need to tune based on sequence length and different models but default to 16 currently.
-BLOCK_SIZE = 16
-
-
 class IPEXPagedCache(Cache):
     """
     A PagedCache that grows dynamically as more tokens are generated, one block-size chunk of memory at a time; vendors can set the PagedCache memory layout.
@@ -48,7 +45,8 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = BLOCK_SIZE
+        default_block_size = 16 if device.type == "cpu" else 64
+        self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
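The hunk above swaps the hard-coded BLOCK_SIZE = 16 for a device-dependent default (16 on CPU, 64 otherwise) that can be overridden via the OI_PAGED_ATTN_BLOCK_SIZE environment variable. Below is a minimal sketch of the resolution and block-count math; the helper name resolve_block_size and the sample device/cache values are illustrative, not part of the PR:

import os

import torch

def resolve_block_size(device: torch.device) -> int:
    # Device-dependent default, overridable through an environment variable.
    default_block_size = 16 if device.type == "cpu" else 64
    return int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))

# Block count = ceil(max_cache_len / block_size) per sequence, times batch size.
device = torch.device("cpu")  # illustrative
max_cache_len, batch_size = 100, 2  # illustrative
block_size = resolve_block_size(device)
num_blocks = (max_cache_len // block_size + (max_cache_len % block_size != 0)) * batch_size
print(block_size, num_blocks)  # 16 14 when the environment variable is unset

Reading the override through int(...) also means a malformed value fails with a ValueError at cache construction rather than corrupting the block-table math later.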
@@ -773,17 +773,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
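Dropping the past_len == 0 guard routes decode steps through the flash-attention path as well, so the varlen call now needs separate query offsets: during prefill every input token is a query and the query offsets coincide with the key/value offsets, while during decode each sequence contributes exactly one query token, giving offsets 0, 1, ..., batch_size. A minimal sketch of that logic; the helper name build_query_lens and the sample input_lens are illustrative:

import torch

def build_query_lens(input_lens: torch.Tensor, past_len: int):
    # Cumulative key/value offsets, one boundary entry per sequence.
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    if past_len == 0:
        # Prefill: all input tokens are queries, so query offsets match KV offsets.
        return seq_len_tensor, int(input_lens.max())
    # Decode: one new query token per sequence -> offsets 0, 1, ..., batch_size.
    query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()
    return query_len_tensor, 1

input_lens = torch.tensor([5, 3], dtype=torch.int32)  # illustrative per-sequence lengths
print(build_query_lens(input_lens, past_len=0))  # (tensor([0, 5, 8], dtype=torch.int32), 5)
print(build_query_lens(input_lens, past_len=9))  # (tensor([0, 1, 2], dtype=torch.int32), 1)

Passing query_max_len = 1 in the decode case tells the kernel that no sequence carries more than one query token, which is what lets a single varlen call cover both prefill and decode.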