
Commit f96edd2

add support for flash decoding on xpu
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent: 2590794

File tree: 2 files changed (+58, -29 lines)

optimum/exporters/ipex/cache_utils.py

+46, -20 lines
@@ -5,6 +5,8 @@
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
 
+from optimum.intel.utils.import_utils import is_ipex_version
+
 
 class IPEXPagedCache(Cache):
     """
@@ -43,6 +45,10 @@ def __init__(
     ) -> None:
         super().__init__()
         self.batch_size = batch_size
+        self.device = device
+        self.flash_decoding = (
+            is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
+        )
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
         default_block_size = 16 if device.type == "cpu" else 64
@@ -69,14 +75,43 @@ def __init__(
             key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
             value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
         elif device.type == "xpu":
-            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
-            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+            if self.flash_decoding:
+                key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+                value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+            else:
+                key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+                value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
         for i in range(config.num_hidden_layers):
             new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
             new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
+    def reshape_and_cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slots: torch.Tensor,
+    ):
+        if self.device.type == "xpu" and self.flash_decoding:
+            PagedAttention.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+        else:
+            PagedAttention.reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+
     def update_for_prefill(
         self,
         key_states: torch.Tensor,
@@ -94,7 +129,7 @@ def update_for_prefill(
                 block_table = self.free_blocks.nonzero().view(-1)[0:nb]
                 self.block_tables[i][0:nb] = block_table
                 self.free_blocks[block_table] = 0
-                slots_range = torch.arange(input_lens[i], device=key_states.device)
+                slots_range = torch.arange(input_lens[i], device=self.device)
                 block_indices = slots_range // self.block_size
                 slot_offsets = slots_range % self.block_size
                 all_block_indices.append(self.block_tables[i][block_indices])
@@ -104,12 +139,8 @@ def update_for_prefill(
             all_slot_offsets = torch.cat(all_slot_offsets)
             self.slots = all_block_indices * self.block_size + all_slot_offsets
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
@@ -127,7 +158,7 @@ def update_for_decode(
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
+            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
             for i in range(batch_size):
                 if slot_offset_in_block[i] == 0:
                     # need a new block:
@@ -138,12 +169,8 @@ def update_for_decode(
                         self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
@@ -193,16 +220,15 @@ def get_max_length(self) -> Optional[int]:
 
     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
+        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
         self.max_seq_len = 0
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.block_tables.device
         origin_table = self.block_tables.clone()
-        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
+        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
         updated_table = torch.zeros_like(beam_idx)

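Note on the cache_utils.py change above: the constructor now records the cache device and sets a flash_decoding flag based on the installed IPEX version (newer than 2.4.99 on CPU, newer than 2.5.99 on XPU). When flash decoding is available on XPU, the KV cache is allocated in a (num_blocks, block_size, num_kv_heads, head_size) layout and updates are routed to PagedAttention.reshape_and_cache_flash; otherwise the previous layouts and PagedAttention.reshape_and_cache are kept. The following is a minimal, dependency-free sketch of that shape-selection logic; pick_kv_cache_shapes is a hypothetical helper written for illustration, not code from this patch.

# Hypothetical sketch mirroring the layout selection added in cache_utils.py,
# written without the intel_extension_for_pytorch dependency.
from typing import Tuple


def pick_kv_cache_shapes(
    device_type: str,
    flash_decoding: bool,
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (key_cache_shape, value_cache_shape) the way the patch chooses them."""
    if device_type == "cpu":
        shape = (num_blocks, num_kv_heads, block_size, head_size)
        return shape, shape
    if flash_decoding:
        # XPU flash decoding: identical blocked layouts for key and value caches.
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return shape, shape
    # Legacy XPU layout: the key cache carries a trailing unit dimension.
    return (
        (num_blocks, num_kv_heads, head_size, block_size, 1),
        (num_blocks, num_kv_heads, head_size, block_size),
    )


# Example: XPU shapes with and without flash decoding.
print(pick_kv_cache_shapes("xpu", True, num_blocks=32, block_size=64, num_kv_heads=8, head_size=128))
print(pick_kv_cache_shapes("xpu", False, num_blocks=32, block_size=64, num_kv_heads=8, head_size=128))

Keeping the dispatch inside a single reshape_and_cache method lets the prefill and decode paths share one device/version check instead of repeating it at every call site.
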
optimum/exporters/ipex/modeling_utils.py

+12, -9 lines
@@ -628,10 +628,10 @@ def postprocess_attention_output(self, attn_output):
         return attn_output
 
     # Maybe removed after torch 2.6 released
-    def has_flash_attn(self, query):
-        if query.device.type == "cpu":
+    def has_flash_attn(self):
+        if self.module_device.type == "cpu":
             return is_torch_version(">", "2.4.99")
-        elif query.device.type == "xpu":
+        elif self.module_device.type == "xpu":
             return is_torch_version(">", "2.5.99")
 
     def attention_interface(
@@ -652,20 +652,23 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn():
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
+            query_len_tensor = (
+                seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int()
+            )
+            max_input_lens = input_lens.max().item()
+            query_max_len = max_input_lens if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
-                key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
-                value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                key_cache,
+                value_cache,
                 query_len_tensor,
                 seq_len_tensor,
                 query_max_len,
-                input_lens.max(),
+                max_input_lens,
                 1.0 / math.sqrt(self.head_dim),
                 True,
                 past_key_value.block_tables,

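For context on the attention_interface change above: has_flash_attn now consults the module's device rather than the query tensor, the decode-path arange for query_len_tensor is created on the query's device, and the maximum input length is read once as a Python int (max_input_lens) and reused. Below is a small sketch, under the assumption that input_lens holds per-sequence token counts, of how the varlen length tensors passed to the flash kernel are derived; build_varlen_lengths is an illustrative name, not a function from this patch.

# Illustrative sketch (not from the patch): build the cumulative sequence-length
# tensors that a flash_attn_varlen_func-style kernel consumes.
import torch


def build_varlen_lengths(input_lens: torch.Tensor, past_len: int):
    """Return (seq_len_tensor, query_len_tensor, query_max_len) as in the diff."""
    # Prefix sums of per-sequence lengths, prepended with 0: [0, l0, l0+l1, ...]
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    if past_len == 0:
        # Prefill: query boundaries coincide with key/value boundaries.
        query_len_tensor = seq_len_tensor
        query_max_len = int(input_lens.max().item())
    else:
        # Decode: one new query token per sequence, so boundaries are 0, 1, 2, ...
        query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=input_lens.device).int()
        query_max_len = 1
    return seq_len_tensor, query_len_tensor, query_max_len


# Example: two sequences of lengths 5 and 3 during prefill.
lens = torch.tensor([5, 3], dtype=torch.int32)
print(build_varlen_lengths(lens, past_len=0))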