
Commit 35cd0c1

authored Oct 17, 2024
refine class IPEXPagedCache's update method (#945)
* refine class IPEXPagedCache's update method
* replace tensor on xpu with a List to avoid a memory copy
* split IPEXPagedCache's update function into `update_for_prefill` and `update_for_decode`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 541a236 commit 35cd0c1
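Background for the diff below: a paged KV cache stores tokens in fixed-size blocks and addresses each token by a flat slot index derived from the sequence's block table. A minimal, self-contained sketch of that mapping (the helper `slot_for_position` is illustrative only and not part of this commit; `block_tables` and `block_size` mirror the IPEXPagedCache fields):

# Illustrative sketch (not part of this commit): how a paged KV cache maps a
# token position to a flat slot index.
from typing import List

def slot_for_position(block_tables: List[List[int]], block_size: int, seq: int, pos: int) -> int:
    block_idx = pos // block_size                   # logical block the token falls in
    offset = pos % block_size                       # offset of the token inside that block
    physical_block = block_tables[seq][block_idx]   # physical block assigned to that logical block
    return physical_block * block_size + offset

# With block_size=16, token 35 of sequence 0 sits in logical block 2 (physical block 7):
assert slot_for_position([[2, 5, 7]], 16, 0, 35) == 7 * 16 + 3  # slot 115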

File tree

2 files changed: +83 −35 lines
 

optimum/exporters/ipex/cache_utils.py

file mode changed: 100644 → 100755; +78 −30
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
@@ -95,14 +95,87 @@ def __init__(
             for _ in range(self.num_hidden_layers)
         ]
 
+    def update_for_prefill(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+        length_list: Optional[List],
+    ):
+        all_block_indices = []
+        all_slot_offsets = []
+        for i in range(batch_size):
+            num_blocks = (length_list[i] + self.block_size - 1) // self.block_size
+            for b_idx in range(num_blocks):
+                if self.block_tables[i][b_idx] == -1:
+                    # need a free block
+                    self.block_tables[i][b_idx] = self.free_blocks.pop(0)
+
+            slots_range = torch.arange(length_list[i], device=key_states.device)
+            block_indices = slots_range // self.block_size
+            slot_offsets = slots_range % self.block_size
+            all_block_indices.append(self.block_tables[i][block_indices])
+            all_slot_offsets.append(slot_offsets)
+
+        all_block_indices = torch.cat(all_block_indices)
+        all_slot_offsets = torch.cat(all_slot_offsets)
+        slots_tensor = all_block_indices * self.block_size + all_slot_offsets
+        # Update the cache
+        PagedAttention.reshape_and_cache(
+            key_states,
+            value_states,
+            self.kv_cache[layer_idx][0],
+            self.kv_cache[layer_idx][1],
+            slots_tensor,
+        )
+
+        # Update the number of seen tokens
+        if layer_idx == self.num_hidden_layers - 1:
+            for i in range(batch_size):
+                self._seen_tokens[i] += length_list[i]
+
+    def update_for_decode(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+    ):
+        slots = []
+        for i in range(batch_size):
+            start_block_idx = self._seen_tokens[i] // self.block_size
+            num_blocks = (self._seen_tokens[i] + self.block_size) // self.block_size
+            for b_idx in range(start_block_idx, num_blocks):
+                if self.block_tables[i][b_idx] == -1:
+                    # need a free block
+                    self.block_tables[i][b_idx] = self.free_blocks.pop(0)
+            block_idx = (self._seen_tokens[i]) // self.block_size
+            slot_offset_in_block = (self._seen_tokens[i]) % self.block_size
+            slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)
+
+        # Update the cache
+        PagedAttention.reshape_and_cache(
+            key_states,
+            value_states,
+            self.kv_cache[layer_idx][0],
+            self.kv_cache[layer_idx][1],
+            torch.tensor(slots, device=key_states.device),
+        )
+
+        # Update the number of seen tokens
+        if layer_idx == self.num_hidden_layers - 1:
+            for i in range(batch_size):
+                self._seen_tokens[i] += 1
+
     def update(
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
         layer_idx: int,
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        input_lens: torch.Tensor,
+        length_list: Optional[List],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -117,39 +190,14 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
+
         batch_size = position_ids.shape[0]
-        slots = []
         if self.get_seq_length() == 0:
             # prefill
-            num_slots = input_lens.tolist()
+            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, length_list)
         else:
             # decode
-            num_slots = [1] * batch_size
-            for i in range(batch_size):
-                start_block_idx = self._seen_tokens[i] // self.block_size
-                num_blocks = (self._seen_tokens[i] + num_slots[i] + self.block_size - 1) // self.block_size
-                for b_idx in range(start_block_idx, num_blocks):
-                    if self.block_tables[i][b_idx] == -1:
-                        # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks.pop(0)
-                for slot in range(num_slots[i]):
-                    block_idx = (self._seen_tokens[i] + slot) // self.block_size
-                    slot_offset_in_block = (self._seen_tokens[i] + slot) % self.block_size
-                    slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)
-
-        # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.kv_cache[layer_idx][0],
-            self.kv_cache[layer_idx][1],
-            torch.tensor(slots, device=key_states.device),
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            for i in range(batch_size):
-                self._seen_tokens[i] += num_slots[i]
+            self.update_for_decode(key_states, value_states, layer_idx, batch_size)
 
         return self.kv_cache[layer_idx][0], self.kv_cache[layer_idx][1]
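The prefill path above builds the slot indices as a single tensor (`slots_tensor`) rather than a Python list, avoiding a host-to-device copy on XPU; the decode path only ever adds one slot per sequence, so a plain list stays cheap there. A standalone sketch of the vectorized prefill arithmetic with toy values (only `torch` is assumed; names mirror the method above):

import torch

block_size = 4
block_table = torch.tensor([3, 8, 1])   # physical blocks already assigned to one sequence
seq_len = 10                            # prompt length to prefill for that sequence

slots_range = torch.arange(seq_len)              # token positions 0..9
block_indices = slots_range // block_size        # logical block of each token
slot_offsets = slots_range % block_size          # offset inside that block
slots = block_table[block_indices] * block_size + slot_offsets
print(slots)  # tensor([12, 13, 14, 15, 32, 33, 34, 35,  4,  5])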

optimum/exporters/ipex/modeling_utils.py

file mode changed: 100644 → 100755; +5 −5
@@ -123,7 +123,7 @@ def _llama_model_forward(
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     input_lens = attention_mask.cumsum(-1)[:, -1]
-
+    lens_list = input_lens.tolist()
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -137,6 +137,7 @@ def _llama_model_forward(
                 use_cache=use_cache,
                 position_embeddings=position_embeddings,
                 input_lens=input_lens.int(),
+                lens_list=lens_list,
             )
 
             hidden_states = layer_outputs[0]
@@ -210,6 +211,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         input_lens: Optional[torch.Tensor] = None,
+        lens_list: Optional[List] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
@@ -227,15 +229,13 @@ def forward(
 
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(
-                key, value, self.layer_idx, attention_mask, position_ids, input_lens
+                key, value, self.layer_idx, attention_mask, position_ids, lens_list
             )
 
         attn_output = torch.empty_like(query)
         if past_len == 0:
             # prefill, remove padding
-            seq_len_tensor = torch.cat(
-                (torch.tensor([0], device=input_lens.device, dtype=torch.int), input_lens.cumsum(-1).int())
-            )
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
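The simplified `seq_len_tensor` line relies on `Tensor.new_tensor`, which creates the zero prefix with the same dtype and device as `input_lens`, so the explicit `device=`/`dtype=` arguments of the old `torch.tensor` call are no longer needed. A toy illustration with made-up values:

import torch

input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # per-sequence lengths
# new_tensor([0]) inherits input_lens's dtype and device, no explicit arguments required
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)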
