@@ -88,8 +88,9 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            block_table = self.free_blocks.nonzero().view(-1)[0 : num_blocks[i]]
-            self.block_tables[i][0 : num_blocks[i]] = block_table
+            nb = num_blocks[i]
+            block_table = self.free_blocks.nonzero().view(-1)[0:nb]
+            self.block_tables[i][0:nb] = block_table
             self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
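
For intuition, the allocation step above is plain ceil-division block math: a prompt of `input_lens[i]` tokens needs `ceil(input_lens[i] / block_size)` KV blocks, and every token slot then maps to a (block, offset) pair. Below is a minimal standalone sketch of that logic; the `allocate_blocks` helper and the toy sizes are hypothetical, and the final slot formula is an assumption extrapolated from the `slots_range` / `block_indices` context lines, not code shown in this diff.

```python
import torch

def allocate_blocks(free_blocks: torch.Tensor, input_len: int, block_size: int):
    # Hypothetical standalone version of the allocation step in the hunk above.
    num_blocks = (input_len + block_size - 1) // block_size     # ceil division
    block_table = free_blocks.nonzero().view(-1)[0:num_blocks]  # first free block ids
    free_blocks[block_table] = 0                                # mark them as used
    slots_range = torch.arange(input_len)
    block_indices = slots_range // block_size                   # per-token logical block
    # Assumed slot mapping: physical block id * block_size + offset inside the block.
    slots = block_table[block_indices] * block_size + slots_range % block_size
    return block_table, slots

free = torch.ones(8, dtype=torch.int64)  # 8 physical blocks, all free
table, slots = allocate_blocks(free, input_len=5, block_size=4)
print(table)  # tensor([0, 1]) -- ceil(5 / 4) == 2 blocks claimed
print(slots)  # tensor([0, 1, 2, 3, 4]) -- four slots in block 0, one in block 1
```

Hoisting `num_blocks[i]` into `nb` avoids indexing that tensor three times per sequence; the allocation itself is unchanged.
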
@@ -202,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
-        updated_table = []
+        updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
-            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
-            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
-        updated_table = torch.cat(tuple(updated_table), dim=0)
+            nb = num_blocks[i]
+            self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
+            updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks[free_table] = 1
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1

     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
@@ -231,4 +234,6 @@ def crop(self, maximum_length: int):
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks[free_table] = 1
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1
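
The same guard lands in `crop` because rows can still share prefix blocks left over from earlier reorders. A toy run (all block ids made up for illustration) shows the difference the membership check makes:

```python
import torch

# Made-up state: row 1's table was rewritten, and block 3 is still used by row 0.
origin_table = torch.tensor([[3, 5, -1],
                             [3, 4, -1]])
block_tables = torch.tensor([[3, 5, -1],
                             [5, 6, -1]])
free_blocks = torch.zeros(8, dtype=torch.int64)

free_table = torch.unique(origin_table[origin_table != block_tables].view(-1))
# free_table == tensor([3, 4]); blanket freeing would release block 3,
# which row 0 still references, corrupting its KV data on the next allocation.
for i in free_table:
    if not (block_tables == i).any():
        free_blocks[i] = 1
print(free_blocks)  # tensor([0, 0, 0, 0, 1, 0, 0, 0]) -- only block 4 is freed
```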