
Commit d8cd7b9

use free_table as a mask tensor

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

1 parent 8ef3997 · commit d8cd7b9

1 file changed: +12 −16 lines

optimum/exporters/ipex/cache_utils.py (+12 −16)
```diff
@@ -49,7 +49,7 @@ def __init__(
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
         )
-        self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
         self.max_cache_len = max_cache_len
         self.num_kv_heads = config.num_key_value_heads
         self.num_hidden_layers = config.num_hidden_layers
```
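This hunk is the core of the commit: `free_blocks` changes from a 1-D tensor of free block ids (consumed from the front) to a fixed-size 0/1 mask indexed by block id, where 1 means "free". A minimal standalone sketch of the two representations, using a toy pool size chosen purely for illustration:

```python
import torch

num_blocks = 8  # toy pool size, for illustration only

# Before: a queue of free block ids; allocation pops from the front,
# freeing concatenates ids back on, so the tensor keeps changing shape.
free_queue = torch.arange(num_blocks)  # tensor([0, 1, ..., 7])

# After: a fixed-size occupancy mask indexed by block id (1 = free).
free_mask = torch.ones([num_blocks], dtype=torch.int32)

# The free ids are still recoverable from the mask when needed.
assert torch.equal(free_mask.nonzero().view(-1), free_queue)
```

A mask keeps `free_blocks` at a fixed shape on a fixed device for the lifetime of the cache, instead of reallocating the tensor on every allocate and free.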
```diff
@@ -88,12 +88,9 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            for b_idx in range(num_blocks[i]):
-                if self.block_tables[i][b_idx] == -1:
-                    # need a free block
-                    self.block_tables[i][b_idx] = self.free_blocks[0]
-                    self.free_blocks = self.free_blocks[1:]
-
+            block_table = self.free_blocks.nonzero().view(-1)[0 : num_blocks[i]]
+            self.block_tables[i][0 : num_blocks[i]] = block_table
+            self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
             slot_offsets = slots_range % self.block_size
```
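With the mask, prefill can grab all the blocks a sequence needs in one indexing operation instead of popping them one at a time in a Python loop. A sketch of that allocation step (not from the commit), with made-up occupancy values, mirroring the `nonzero().view(-1)` pattern above:

```python
import torch

free_blocks = torch.ones([8], dtype=torch.int32)
free_blocks[torch.tensor([1, 4])] = 0  # pretend blocks 1 and 4 are already taken

n = 3  # blocks needed for this sequence's prompt
block_table = free_blocks.nonzero().view(-1)[0:n]  # first n free ids: 0, 2, 3
free_blocks[block_table] = 0                       # mark them allocated
print(block_table)  # tensor([0, 2, 3])
```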
```diff
@@ -103,7 +100,6 @@ def update_for_prefill(
         all_block_indices = torch.cat(all_block_indices)
         all_slot_offsets = torch.cat(all_slot_offsets)
         self.slots = all_block_indices * self.block_size + all_slot_offsets
-
         # Update the cache
         PagedAttention.reshape_and_cache(
             key_states,
```
```diff
@@ -127,16 +123,16 @@ def update_for_decode(
     ):
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
-            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
             self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
             for i in range(batch_size):
-                for b_idx in range(start_block_idx[i], num_blocks[i]):
+                if slot_offset_in_block[i] == 0:
+                    # need a new block:
+                    b_idx = start_block_idx[i]
                     if self.block_tables[i][b_idx] == -1:
                         # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks[0]
-                        self.free_blocks = self.free_blocks[1:]
-
+                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
+                        self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
         PagedAttention.reshape_and_cache(
```
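The decode path also drops the `num_blocks` computation: since each sequence appends exactly one token per step, a new block can only be needed when the write offset lands on a block boundary, so the single `slot_offset_in_block[i] == 0` check replaces the old loop over block indices. A sketch of that boundary condition (not from the commit) with toy token counts:

```python
import torch

block_size = 16                              # toy value, for illustration
seen_tokens = torch.tensor([15, 16, 33])     # tokens already cached per sequence

start_block_idx = seen_tokens // block_size  # block the next token falls into
slot_offset_in_block = seen_tokens % block_size

print(start_block_idx)            # tensor([0, 1, 2])
print(slot_offset_in_block == 0)  # tensor([False,  True, False])
# only the second sequence sits exactly on a boundary and needs a new block
```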
```diff
@@ -196,7 +192,7 @@ def reset(self):
         """Resets the cache values while preserving the objects"""
         self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
         self.max_seq_len = 0

     def reorder_cache(self, beam_idx: torch.LongTensor):
```
```diff
@@ -215,7 +211,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        self.free_blocks[free_table] = 1

     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
```
```diff
@@ -235,4 +231,4 @@ def crop(self, maximum_length: int):
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        self.free_blocks[free_table] = 1
```
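Freeing in `reorder_cache` and `crop` becomes a masked write instead of `torch.cat`, so the pool tensor never grows. A sketch (not from the commit) with made-up freed ids:

```python
import torch

free_blocks = torch.zeros([8], dtype=torch.int32)  # pretend every block is allocated
free_table = torch.tensor([2, 5, 7])               # block ids released by crop/reorder

free_blocks[free_table] = 1                        # return them to the pool
print(free_blocks.nonzero().view(-1))              # tensor([2, 5, 7])
```

Setting a mask entry to 1 twice is also harmless, whereas concatenating an id that was somehow already in the old free list would have let the same block be handed out twice.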
