from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig

+from optimum.intel.utils.import_utils import is_ipex_version
+

class IPEXPagedCache(Cache):
    """
@@ -43,10 +45,14 @@ def __init__(
    ) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
+        self.device = device
+        self._supports_flash_decoding = (
+            is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
+        )
        # Used in `generate` to keep tally of how many tokens the cache has seen
        self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
-        default_block_size = 16 if device.type == "cpu" else 64
+        default_block_size = 16
        self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
        self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
        self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
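
Note: the new `_supports_flash_decoding` flag boils down to a plain version comparison. Flash decoding is assumed available from IPEX 2.5 on CPU and from IPEX 2.6 on XPU, which is what the `">", "2.4.99"` / `">", "2.5.99"` checks encode. A minimal standalone sketch of the same gate, using `packaging` directly instead of `optimum.intel.utils.import_utils.is_ipex_version` (the function name below is illustrative, not part of this PR):

```python
from packaging import version

def supports_flash_decoding(ipex_version: str, device_type: str) -> bool:
    # Mirrors the check in __init__: CPU gains flash decoding with IPEX 2.5,
    # XPU with IPEX 2.6 (hence the ">2.4.99" / ">2.5.99" comparisons).
    threshold = "2.4.99" if device_type == "cpu" else "2.5.99"
    return version.parse(ipex_version) > version.parse(threshold)

assert supports_flash_decoding("2.5.0", "cpu") is True
assert supports_flash_decoding("2.5.0", "xpu") is False
assert supports_flash_decoding("2.6.0", "xpu") is True
```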
@@ -70,14 +76,44 @@ def __init__(
            key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
            value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
        elif device.type == "xpu":
-            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
-            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+            if self._supports_flash_decoding:
+                key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+                value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+            else:
+                key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+                value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
        for i in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
            new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

+    def reshape_and_cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slots: torch.Tensor,
+    ):
+        # TODO: unify API definition between CPU and XPU in IPEX version > 2.6
+        if self.device.type == "xpu" and self._supports_flash_decoding:
+            PagedAttention.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+        else:
+            PagedAttention.reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+
    def update_for_prefill(
        self,
        key_states: torch.Tensor,
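
For orientation, the two XPU branches above differ only in memory layout: the flash-decoding path stores whole blocks of tokens contiguously with heads innermost (the layout `reshape_and_cache_flash` consumes), while the legacy path keeps a per-head, head_size-major layout with a trailing unit dimension on keys. A toy comparison with made-up sizes:

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64

# Flash-decoding XPU layout (identical for keys and values).
flash_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Legacy XPU layout: keys carry an extra trailing unit dimension.
legacy_key_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size, 1)
legacy_value_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size)

# Same number of elements per layer either way; only the strides change.
assert flash_cache.numel() == legacy_key_cache.numel() == legacy_value_cache.numel()
```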
@@ -95,7 +131,7 @@ def update_for_prefill(
                block_table = self.free_blocks.nonzero().view(-1)[0:nb]
                self.block_tables[i][0:nb] = block_table
                self.free_blocks[block_table] = 0
-                slots_range = torch.arange(input_lens[i], device=key_states.device)
+                slots_range = torch.arange(input_lens[i], device=self.device)
                block_indices = slots_range // self.block_size
                slot_offsets = slots_range % self.block_size
                all_block_indices.append(self.block_tables[i][block_indices])
@@ -105,12 +141,8 @@ def update_for_prefill(
            all_slot_offsets = torch.cat(all_slot_offsets)
            self.slots = all_block_indices * self.block_size + all_slot_offsets
        # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
        )

        # Update the number of seen tokens
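
The `self.slots` computed above follows standard paged-KV addressing: a token's position splits into a logical block index and an in-block offset, and the per-sequence block table maps logical blocks to physical ones. A self-contained illustration (all values made up):

```python
import torch

# Illustrative; the real cache reads block_size from OI_PAGED_ATTN_BLOCK_SIZE.
block_size = 16
block_table = torch.tensor([7, 2, 11], dtype=torch.int32)  # logical -> physical blocks

def slot_for_position(pos: int) -> int:
    """Map a token position in one sequence to a physical cache slot."""
    block_idx = pos // block_size  # logical block holding this token
    offset = pos % block_size      # position inside that block
    return int(block_table[block_idx]) * block_size + offset

# Token 20 lives in logical block 1 (physical block 2), offset 4 -> slot 36.
assert slot_for_position(20) == 2 * block_size + 4
```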
@@ -128,7 +160,7 @@ def update_for_decode(
        if layer_idx == 0:
            start_block_idx = self._seen_tokens // self.block_size
            slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
+            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
            for i in range(batch_size):
                if slot_offset_in_block[i] == 0:
                    # need a new block:
@@ -139,12 +171,8 @@ def update_for_decode(
                    self.free_blocks[self.block_tables[i][b_idx]] = 0
                self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
        # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
        )

        # Update the number of seen tokens
@@ -194,16 +222,15 @@ def get_max_length(self) -> Optional[int]:

    def reset(self):
        """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.block_tables.device)
+        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device)
        self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
        self.max_seq_len = 0

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.block_tables.device
        origin_table = self.block_tables.clone()
-        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
+        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
        mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
        num_blocks = mask.cumsum(-1)[:, -1]
        updated_table = torch.zeros_like(beam_idx)
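
A side note on `reorder_cache`: the block-table reordering is a single `index_select` over the batch dimension, after which freed and duplicated blocks are reconciled via the mask. A stripped-down illustration of that `index_select` step with hypothetical tables:

```python
import torch

# Hypothetical per-sequence block tables (-1 marks unused entries).
block_tables = torch.tensor([[0, 1, -1],
                             [2, 3, 4],
                             [5, -1, -1]], dtype=torch.int32)
beam_idx = torch.tensor([1, 1, 0])  # beams 0 and 1 both fork from old beam 1

updated = block_tables.index_select(0, beam_idx)
# rows are now: old row 1, old row 1, old row 0
print(updated)
```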