File tree 1 file changed +5
-1
lines changed
1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change 5
5
from transformers import Cache , PretrainedConfig
6
6
7
7
8
+ # May need to tune based on sequence length and different models but default to 16 currently.
9
+ BLOCK_SIZE = 16
10
+
11
+
8
12
class IPEXPagedCache (Cache ):
9
13
"""
10
14
A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
@@ -44,7 +48,7 @@ def __init__(
44
48
self .batch_size = batch_size
45
49
# Used in `generate` to keep tally of how many tokens the cache has seen
46
50
self ._seen_tokens = torch .zeros ([batch_size ], dtype = torch .int32 , device = device )
47
- self .block_size = 64
51
+ self .block_size = BLOCK_SIZE
48
52
self .num_blocks = (max_cache_len // self .block_size + (max_cache_len % self .block_size != 0 )) * batch_size
49
53
self .block_tables = - 1 * torch .ones ([self .num_blocks ], dtype = torch .int32 , device = device ).reshape (
50
54
batch_size , - 1
You can’t perform that action at this time.
0 commit comments