File tree 1 file changed +2
-5
lines changed
1 file changed +2
-5
lines changed Original file line number Diff line number Diff line change 6
6
from transformers import Cache , PretrainedConfig
7
7
8
8
9
- # Recommend 16 on CPU and 64 on XPU.
10
- BLOCK_SIZE = int (os .environ .get ("OI_PAGED_ATTN_BLOCK_SIZE" , 16 ))
11
-
12
-
13
9
class IPEXPagedCache (Cache ):
14
10
"""
15
11
A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
@@ -49,7 +45,8 @@ def __init__(
49
45
self .batch_size = batch_size
50
46
# Used in `generate` to keep tally of how many tokens the cache has seen
51
47
self ._seen_tokens = torch .zeros ([batch_size ], dtype = torch .int32 , device = device )
52
- self .block_size = BLOCK_SIZE
48
+ default_block_size = 16 if device .type == "cpu" else 64
49
+ self .block_size = int (os .environ .get ("OI_PAGED_ATTN_BLOCK_SIZE" , str (default_block_size )))
53
50
self .num_blocks = (max_cache_len // self .block_size + (max_cache_len % self .block_size != 0 )) * batch_size
54
51
self .block_tables = - 1 * torch .ones ([self .num_blocks ], dtype = torch .int32 , device = device ).reshape (
55
52
batch_size , - 1
You can’t perform that action at this time.
0 commit comments