
Commit de8c9d7

add support for flash decoding on xpu (#1118)

* add support for flash decoding on xpu
* change block_size to 16 for xpu, as `single_query_cached_kv_attention` API does not support 64
* optimize the performance
* fix format issue
* refine code
* add support for flash decoding on xpu
* fix conflict CI
* fix bug
* fix CI bug

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

1 parent 63bee4e · commit de8c9d7
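
The feature is gated on the installed IPEX version: the new `_supports_flash_decoding` flag in `IPEXPagedCache` requires IPEX > 2.4.99 on CPU and > 2.5.99 on XPU. Below is a minimal sketch of that check using the `is_ipex_version` helper this commit imports; the standalone wrapper function is illustrative and not part of the actual class.

    # Sketch of the capability check added in IPEXPagedCache.__init__.
    # The wrapper function is hypothetical; the real check sets self._supports_flash_decoding.
    from optimum.intel.utils.import_utils import is_ipex_version

    def supports_flash_decoding(device_type: str) -> bool:
        # CPU needs IPEX > 2.4.99; XPU needs IPEX > 2.5.99 for the flash-decoding path.
        if device_type == "cpu":
            return is_ipex_version(">", "2.4.99")
        return is_ipex_version(">", "2.5.99")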

File tree

3 files changed: +114, -39 lines changed

optimum/exporters/ipex/cache_utils.py

+48-21
@@ -5,6 +5,8 @@
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
 
+from optimum.intel.utils.import_utils import is_ipex_version
+
 
 class IPEXPagedCache(Cache):
     """
@@ -43,10 +45,14 @@ def __init__(
     ) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
+        self.device = device
+        self._supports_flash_decoding = (
+            is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
+        )
         # Used in `generate` to keep tally of how many tokens the cache has seen
 
         self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
-        default_block_size = 16 if device.type == "cpu" else 64
+        default_block_size = 16
         self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
@@ -70,14 +76,44 @@ def __init__(
             key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
             value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
         elif device.type == "xpu":
-            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
-            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+            if self._supports_flash_decoding:
+                key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+                value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+            else:
+                key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+                value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
         for i in range(config.num_hidden_layers):
             new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
             new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
+    def reshape_and_cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slots: torch.Tensor,
+    ):
+        # TODO: unify API definition between CPU and XPU in IPEX version > 2.6
+        if self.device.type == "xpu" and self._supports_flash_decoding:
+            PagedAttention.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+        else:
+            PagedAttention.reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+
     def update_for_prefill(
         self,
         key_states: torch.Tensor,
@@ -95,7 +131,7 @@ def update_for_prefill(
                 block_table = self.free_blocks.nonzero().view(-1)[0:nb]
                 self.block_tables[i][0:nb] = block_table
                 self.free_blocks[block_table] = 0
-                slots_range = torch.arange(input_lens[i], device=key_states.device)
+                slots_range = torch.arange(input_lens[i], device=self.device)
                 block_indices = slots_range // self.block_size
                 slot_offsets = slots_range % self.block_size
                 all_block_indices.append(self.block_tables[i][block_indices])
@@ -105,12 +141,8 @@ def update_for_prefill(
             all_slot_offsets = torch.cat(all_slot_offsets)
             self.slots = all_block_indices * self.block_size + all_slot_offsets
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
@@ -128,7 +160,7 @@ def update_for_decode(
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
+            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
             for i in range(batch_size):
                 if slot_offset_in_block[i] == 0:
                     # need a new block:
@@ -139,12 +171,8 @@ def update_for_decode(
                        self.free_blocks[self.block_tables[i][b_idx]] = 0
                self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
        )
 
         # Update the number of seen tokens
@@ -194,16 +222,15 @@ def get_max_length(self) -> Optional[int]:
 
     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device)
+        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
         self.max_seq_len = 0
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.block_tables.device
         origin_table = self.block_tables.clone()
-        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
+        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
         updated_table = torch.zeros_like(beam_idx)
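
To summarize the layout change above: on XPU with a new enough IPEX, the per-layer KV cache switches to a (num_blocks, block_size, num_kv_heads, head_size) layout and is written via PagedAttention.reshape_and_cache_flash; otherwise the previous layouts and PagedAttention.reshape_and_cache are kept. A small self-contained sketch of the shape selection follows; the standalone function name and signature are illustrative, the real logic lives in IPEXPagedCache.__init__.

    # Sketch only: restates the shape-selection logic from IPEXPagedCache.__init__.
    # The function name and signature are hypothetical.
    def select_kv_cache_shapes(device_type, supports_flash_decoding, num_blocks, block_size, num_kv_heads, head_size):
        if device_type == "cpu":
            key_shape = (num_blocks, num_kv_heads, block_size, head_size)
            value_shape = (num_blocks, num_kv_heads, block_size, head_size)
        elif device_type == "xpu" and supports_flash_decoding:
            # New XPU flash-decoding layout, written via PagedAttention.reshape_and_cache_flash
            key_shape = (num_blocks, block_size, num_kv_heads, head_size)
            value_shape = (num_blocks, block_size, num_kv_heads, head_size)
        else:
            # Legacy XPU layout, written via PagedAttention.reshape_and_cache
            key_shape = (num_blocks, num_kv_heads, head_size, block_size, 1)
            value_shape = (num_blocks, num_kv_heads, head_size, block_size)
        return key_shape, value_shape

    # The default block_size is now 16 on both CPU and XPU, since
    # single_query_cached_kv_attention does not support 64.
    print(select_kv_cache_shapes("xpu", True, num_blocks=8, block_size=16, num_kv_heads=4, head_size=64))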

optimum/exporters/ipex/modeling_utils.py

+65-17
@@ -206,8 +206,8 @@ def _llama_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
     if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
@@ -227,6 +227,9 @@ def _llama_model_forward(
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -262,6 +265,9 @@
             use_cache=use_cache,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = layer_outputs[0]
@@ -330,11 +336,10 @@ def _falcon_model_forward(
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     batch_size, seq_length, _ = inputs_embeds.shape
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     if cache_position is None:
-        cache_position = torch.arange(
-            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
-        )
+        cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
 
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
@@ -350,6 +355,9 @@
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -392,6 +400,9 @@
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -486,6 +497,9 @@ def _gpt2_model_forward(
     hidden_states = self.drop(hidden_states)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -521,6 +535,9 @@
             use_cache=use_cache,
             output_attentions=output_attentions,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
         )
 
         hidden_states = outputs[0]
@@ -591,6 +608,7 @@ def _qwen2_model_forward(
         inputs_embeds = self.embed_tokens(input_ids)
 
     batch_size, seq_length = inputs_embeds.shape[:2]
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
 
     past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if cache_position is None:
@@ -615,6 +633,9 @@
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max().item()
 
     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -650,6 +671,9 @@
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
             **kwargs,
         )
 
@@ -704,14 +728,26 @@ def postprocess_attention_output(self, attn_output):
         return attn_output
 
     # Maybe removed after torch 2.6 released
-    def has_flash_attn(self, query):
-        if query.device.type == "cpu":
+    def has_flash_attn(self):
+        if self.module_device.type == "cpu":
             return is_torch_version(">", "2.4.99")
-        elif query.device.type == "xpu":
+        elif self.module_device.type == "xpu":
             return is_torch_version(">", "2.5.99")
 
     def attention_interface(
-        self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+        self,
+        query,
+        key_cache,
+        value_cache,
+        key,
+        value,
+        past_key_value,
+        attention_mask,
+        input_lens,
+        past_len,
+        seq_len_tensor,
+        query_len_tensor,
+        max_input_lens,
     ):
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
@@ -728,20 +764,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn():
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
+            query_len_tensor = seq_len_tensor if past_len == 0 else query_len_tensor
+            query_max_len = max_input_lens if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
-                key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
-                value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                key_cache,
+                value_cache,
                 query_len_tensor,
                 seq_len_tensor,
                 query_max_len,
-                input_lens.max(),
+                max_input_lens,
                 1.0 / math.sqrt(self.head_dim),
                 True,
                 past_key_value.block_tables,
@@ -750,7 +785,6 @@
         elif past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
-            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
@@ -799,6 +833,9 @@ def forward(
        if past_key_value is None and kwargs.get("layer_past", None) is not None:
            past_key_value = kwargs.pop("layer_past", None)
        input_lens = kwargs.pop("input_lens", None)
+        seq_len_tensor = kwargs.pop("seq_len_tensor", None)
+        query_len_tensor = kwargs.pop("query_len_tensor", None)
+        max_input_lens = kwargs.pop("max_input_lens", 0)
        past_len = 0
        if past_key_value is not None:
            past_len = past_key_value.get_seq_length()
@@ -810,7 +847,18 @@
        key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
 
        attn_output = self.attention_interface(
-            query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+            query,
+            key_cache,
+            value_cache,
+            key,
+            value,
+            past_key_value,
+            attention_mask,
+            input_lens,
+            past_len,
+            seq_len_tensor,
+            query_len_tensor,
+            max_input_lens,
        )
 
        attn_output = self.postprocess_attention_output(attn_output)
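
The recurring three added lines in each model forward above hoist the variable-length metadata out of the per-layer attention call: seq_len_tensor, query_len_tensor, and max_input_lens are now computed once per forward pass from the padding mask and handed to every layer, instead of being rebuilt inside attention_interface on each call. A minimal self-contained sketch of that computation (the example mask values are illustrative):

    # Sketch: the per-forward length metadata now passed down to every attention layer.
    # attention_mask is a (batch, seq_len) 0/1 padding mask; the values below are illustrative.
    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]], dtype=torch.int32)
    device = attention_mask.device

    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # real tokens per sequence -> [3, 4]
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))  # cumulative KV offsets -> [0, 3, 7]
    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()  # cumulative query offsets for decode -> [0, 1, 2]
    max_input_lens = input_lens.max().item()  # -> 4

In the reworked attention_interface, the flash path then uses seq_len_tensor and max_input_lens as the query offsets and query length during prefill (past_len == 0), and switches to query_len_tensor with a query length of 1 during decode, matching flash_attn_varlen_func's cumulative-length arguments.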

tests/ipex/test_modeling.py

+1-1
@@ -294,7 +294,7 @@ def test_forward(self, model_arch):
         dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
         ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
-        input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long)
+        input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long).to(DEVICE)
         outputs = ipex_model(input_ids)
 
         self.assertIsInstance(outputs.logits, torch.Tensor)
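
Finally, a hedged end-to-end sketch of the path this commit touches: loading an IPEXModelForCausalLM on XPU and generating, mirroring the test above. The model id is an illustrative placeholder and the torch.xpu availability check assumes an XPU-enabled PyTorch/IPEX stack; neither comes from this commit.

    # Usage sketch (model id and device selection are illustrative assumptions).
    import torch
    from transformers import AutoTokenizer
    from optimum.intel import IPEXModelForCausalLM

    model_id = "gpt2"  # placeholder; any causal LM architecture supported by the IPEX exporter
    device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
    dtype = torch.float16 if device == "xpu" else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device)

    # Inputs must live on the same device as the model (this is what the test fix above does).
    inputs = tokenizer("Paged attention with flash decoding", return_tensors="pt").to(device)
    generated = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))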
