Commit 4af8fd0

fix beamsearch issue
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent: d8cd7b9

File tree: 1 file changed (+13, -8)

optimum/exporters/ipex/cache_utils.py

@@ -88,8 +88,9 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            block_table = self.free_blocks.nonzero().view(-1)[0 : num_blocks[i]]
-            self.block_tables[i][0 : num_blocks[i]] = block_table
+            nb = num_blocks[i]
+            block_table = self.free_blocks.nonzero().view(-1)[0:nb]
+            self.block_tables[i][0:nb] = block_table
             self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
@@ -202,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
-        updated_table = []
+        updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
-            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
-            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
-        updated_table = torch.cat(tuple(updated_table), dim=0)
+            nb = num_blocks[i]
+            self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
+            updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks[free_table] = 1
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1
 
     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
@@ -231,4 +234,6 @@ def crop(self, maximum_length: int):
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks[free_table] = 1
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1
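
Why the guarded free matters: after reorder_cache copies block tables between beams, several beams can point at the same physical KV-cache block, and the old unconditional self.free_blocks[free_table] = 1 could return a block to the free pool while another beam still referenced it, letting a later prefill overwrite live cache data. The standalone Python sketch below (hypothetical block numbers and beam_idx, not the library code) reproduces just that freeing step:

import torch

# Hypothetical state for 3 beams with 2 KV-cache blocks each.
origin_table = torch.tensor([[0, 1], [2, 3], [4, 5]])  # block tables before reorder_cache
block_tables = torch.tensor([[0, 1], [0, 3], [2, 5]])  # tables after reordering with beam_idx = [0, 0, 1]
free_blocks = torch.zeros(6, dtype=torch.int64)        # 1 marks a block as free

# Blocks whose table slot changed for at least one beam.
free_table = torch.unique((origin_table[origin_table != block_tables]).view(-1))
print(free_table.tolist())  # [2, 4]

# Old behaviour: free_blocks[free_table] = 1 would also release block 2,
# which the third beam still references, so its cached keys/values could be clobbered.

# Fixed behaviour: release a block only when no beam references it any more.
for i in free_table:
    if not (block_tables == i).any():
        free_blocks[i] = 1

print(free_blocks.tolist())  # [0, 0, 0, 0, 1, 0] -> only block 4 is freed

The same guard is applied in crop, and preallocating updated_table = torch.zeros_like(beam_idx) keeps the per-beam "last block" indices in a single LongTensor instead of building them with a Python list and torch.cat.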
