@@ -34,22 +34,23 @@ class IPEXPagedCache(Cache):
34
34
def __init__ (
35
35
self ,
36
36
config : PretrainedConfig ,
37
- batch_size : int ,
37
+ max_batch_size : int ,
38
38
max_cache_len : int ,
39
39
device ,
40
40
dtype = None ,
41
41
layer_device_map = None ,
42
42
** kwargs ,
43
43
) -> None :
44
44
super ().__init__ ()
45
- self .batch_size = batch_size
45
+ self .max_batch_size = max_batch_size
46
46
# Used in `generate` to keep tally of how many tokens the cache has seen
47
- self ._seen_tokens = torch .zeros ([batch_size ], dtype = torch .int32 , device = device )
47
+
48
+ self ._seen_tokens = torch .zeros ([max_batch_size ], dtype = torch .int32 , device = device )
48
49
default_block_size = 16 if device .type == "cpu" else 64
49
50
self .block_size = int (os .environ .get ("OI_PAGED_ATTN_BLOCK_SIZE" , str (default_block_size )))
50
- self .num_blocks = (max_cache_len // self .block_size + (max_cache_len % self .block_size != 0 )) * batch_size
51
+ self .num_blocks = (max_cache_len // self .block_size + (max_cache_len % self .block_size != 0 )) * max_batch_size
51
52
self .block_tables = - 1 * torch .ones ([self .num_blocks ], dtype = torch .int32 , device = device ).reshape (
52
- batch_size , - 1
53
+ max_batch_size , - 1
53
54
)
54
55
self .free_blocks = torch .ones ([self .num_blocks ], dtype = torch .int32 , device = device )
55
56
self .max_cache_len = max_cache_len
@@ -193,7 +194,7 @@ def get_max_length(self) -> Optional[int]:
193
194
194
195
def reset (self ):
195
196
"""Resets the cache values while preserving the objects"""
196
- self ._seen_tokens = torch .zeros ([self .batch_size ], dtype = torch .int32 , device = self .block_tables .device )
197
+ self ._seen_tokens = torch .zeros ([self .max_batch_size ], dtype = torch .int32 , device = self .block_tables .device )
197
198
self .block_tables .fill_ (- 1 )
198
199
self .free_blocks = torch .ones ([self .num_blocks ], dtype = torch .int32 , device = self .block_tables .device )
199
200
self .max_seq_len = 0
0 commit comments