@@ -49,7 +49,7 @@ def __init__(
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
         )
-        self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
         self.max_cache_len = max_cache_len
         self.num_kv_heads = config.num_key_value_heads
         self.num_hidden_layers = config.num_hidden_layers
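The hunk above switches the free-block pool from a shrinking tensor of free block ids to a fixed-size occupancy mask (1 = free, 0 = in use). A minimal standalone sketch of the two representations, with illustrative sizes and names rather than the class's actual fields:

import torch

num_blocks = 8

# Old scheme: a list of free block ids that shrinks on every allocation.
free_list = torch.arange(num_blocks)
first_free = free_list[0]                      # pop the first free block id
free_list = free_list[1:]                      # rebuilds the tensor on every allocation

# New scheme: a fixed-size mask over all blocks, updated in place.
free_mask = torch.ones([num_blocks], dtype=torch.int32)
block_id = free_mask.nonzero().view(-1)[0]     # lowest free block id
free_mask[block_id] = 0                        # allocate it
free_mask[block_id] = 1                        # free it again: a single scatter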
@@ -88,12 +88,9 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            for b_idx in range(num_blocks[i]):
-                if self.block_tables[i][b_idx] == -1:
-                    # need a free block
-                    self.block_tables[i][b_idx] = self.free_blocks[0]
-                    self.free_blocks = self.free_blocks[1:]
-
+            block_table = self.free_blocks.nonzero().view(-1)[0 : num_blocks[i]]
+            self.block_tables[i][0 : num_blocks[i]] = block_table
+            self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
             slot_offsets = slots_range % self.block_size
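With the mask, prefill can reserve all the blocks a sequence needs in a single lookup instead of popping them one by one. A rough sketch of just the allocation step, under the same assumptions (block size, batch size, and tensor shapes are made up):

import torch

block_size = 16
num_blocks = 32
free_blocks = torch.ones([num_blocks], dtype=torch.int32)        # 1 = free
block_tables = -1 * torch.ones([2, 16], dtype=torch.int32)       # -1 = unassigned entry

input_lens = torch.tensor([40, 10])                              # prompt length per sequence
needed = (input_lens + block_size - 1) // block_size             # ceil-divide: blocks per sequence

for i in range(input_lens.shape[0]):
    picked = free_blocks.nonzero().view(-1)[: needed[i]]         # first `needed[i]` free block ids
    block_tables[i][: needed[i]] = picked                        # record them in the block table
    free_blocks[picked] = 0                                      # mark them as in use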
@@ -103,7 +100,6 @@ def update_for_prefill(
         all_block_indices = torch.cat(all_block_indices)
         all_slot_offsets = torch.cat(all_slot_offsets)
         self.slots = all_block_indices * self.block_size + all_slot_offsets
-
         # Update the cache
         PagedAttention.reshape_and_cache(
             key_states,
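For context, the slots computed above are the physical write positions passed to the cache update. Assuming the per-token block indices are looked up through the sequence's block table (those lines fall between the two hunks and are not shown here), each slot is physical_block_id * block_size + offset_in_block. A tiny worked example with made-up numbers:

import torch

block_size = 16
block_table = torch.tensor([5, 9])                     # a 20-token sequence was given physical blocks 5 and 9

positions = torch.arange(20)                           # logical token positions
block_indices = block_table[positions // block_size]   # physical block holding each token
slot_offsets = positions % block_size                  # offset inside that block
slots = block_indices * block_size + slot_offsets      # token 0 -> slot 80, token 16 -> slot 144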
@@ -127,16 +123,16 @@ def update_for_decode(
     ):
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
-            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
             self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
             for i in range(batch_size):
-                for b_idx in range(start_block_idx[i], num_blocks[i]):
+                if slot_offset_in_block[i] == 0:
+                    # need a new block:
+                    b_idx = start_block_idx[i]
                     if self.block_tables[i][b_idx] == -1:
                         # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks[0]
-                        self.free_blocks = self.free_blocks[1:]
-
+                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
+                        self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
         PagedAttention.reshape_and_cache(
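At decode time a single token is appended per sequence, so a new block is needed only when the write position lands exactly on a block boundary. A hedged sketch of that condition with the mask-based pool (one sequence, made-up sizes):

import torch

block_size = 16
free_blocks = torch.ones([8], dtype=torch.int32)
block_tables = -1 * torch.ones([1, 8], dtype=torch.int32)

# Pretend prefill already stored the first 16 tokens of sequence 0 in physical block 3.
block_tables[0][0] = 3
free_blocks[3] = 0
seen_tokens = torch.tensor([16])                     # tokens already in the cache

start_block_idx = seen_tokens // block_size          # block the next token falls into
slot_offset = seen_tokens % block_size               # position inside that block

i = 0
if slot_offset[i] == 0:                              # boundary: the previous block is full
    b_idx = start_block_idx[i]
    if block_tables[i][b_idx] == -1:                 # no block assigned yet
        new_block = free_blocks.nonzero().view(-1)[0]
        block_tables[i][b_idx] = new_block
        free_blocks[new_block] = 0

slot = block_tables[i][start_block_idx[i]] * block_size + slot_offset[i]   # physical slot for the new token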
@@ -196,7 +192,7 @@ def reset(self):
         """Resets the cache values while preserving the objects"""
         self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
         self.max_seq_len = 0

     def reorder_cache(self, beam_idx: torch.LongTensor):
@@ -215,7 +211,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        self.free_blocks[free_table] = 1

     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
@@ -235,4 +231,4 @@ def crop(self, maximum_length: int):
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        self.free_blocks[free_table] = 1
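In both reorder_cache and crop, returning blocks to the pool is now an in-place scatter into the mask rather than a concat onto the free list, which keeps the pool's size fixed and is idempotent if a block id is marked free more than once. A small illustration with made-up block ids:

import torch

free_blocks = torch.zeros([8], dtype=torch.int32)    # pretend every block is currently in use
freed = torch.tensor([2, 5, 7])                      # blocks no longer referenced by any block table

free_blocks[freed] = 1                               # return them to the pool in place
# Marking the same ids free again is harmless; a concat-based free list would accumulate duplicates.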