@@ -5,6 +5,8 @@
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
 
+from optimum.intel.utils.import_utils import is_ipex_version
+
 
 class IPEXPagedCache(Cache):
     """
@@ -43,6 +45,10 @@ def __init__(
     ) -> None:
         super().__init__()
         self.batch_size = batch_size
+        self.device = device
+        self.flash_decoding = (
+            is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
+        )
         # Used in `generate` to keep a tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
         default_block_size = 16 if device.type == "cpu" else 64
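
The new `flash_decoding` flag is a pure version gate: the flash-decoding layout is assumed usable on CPU from IPEX releases above 2.4.99 (i.e. 2.5+) and on XPU above 2.5.99 (i.e. 2.6+). A minimal standalone sketch of the same predicate, assuming `is_ipex_version` reduces to a `packaging.version` comparison (the real helper in `optimum.intel` may differ):

    import torch
    from packaging import version

    def supports_flash_decoding(device: torch.device, ipex_version: str) -> bool:
        # Hypothetical helper: CPU gets the flash path above 2.4.99, XPU above 2.5.99
        floor = "2.4.99" if device.type == "cpu" else "2.5.99"
        return version.parse(ipex_version) > version.parse(floor)

    # supports_flash_decoding(torch.device("cpu"), "2.5.0")  -> True
    # supports_flash_decoding(torch.device("cpu"), "2.4.0")  -> False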
@@ -69,14 +75,43 @@ def __init__(
             key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
             value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
         elif device.type == "xpu":
-            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
-            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
+            if self.flash_decoding:
+                key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+                value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
+            else:
+                key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
+                value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
         for i in range(config.num_hidden_layers):
             new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
             new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
+    def reshape_and_cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slots: torch.Tensor,
+    ):
+        if self.device.type == "xpu" and self.flash_decoding:
+            PagedAttention.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+        else:
+            PagedAttention.reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+            )
+
     def update_for_prefill(
         self,
         key_states: torch.Tensor,
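
For reference, both `PagedAttention.reshape_and_cache` variants scatter the per-token key/value vectors into the preallocated block pool at the flat positions given by `slots`. A rough pure-PyTorch emulation for the flash-decoding layout `(num_blocks, block_size, num_kv_heads, head_size)`, stated as an assumption about the kernel's semantics rather than its implementation:

    import torch

    def reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slots):
        # key/value: (num_tokens, num_kv_heads, head_size)
        # key_cache/value_cache: (num_blocks, block_size, num_kv_heads, head_size)
        # slots: (num_tokens,) flat slot ids, slot = block_id * block_size + offset
        block_size = key_cache.shape[1]
        block_ids = slots // block_size
        offsets = slots % block_size
        key_cache[block_ids, offsets] = key
        value_cache[block_ids, offsets] = value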
@@ -94,7 +129,7 @@ def update_for_prefill(
             block_table = self.free_blocks.nonzero().view(-1)[0:nb]
             self.block_tables[i][0:nb] = block_table
             self.free_blocks[block_table] = 0
-            slots_range = torch.arange(input_lens[i], device=key_states.device)
+            slots_range = torch.arange(input_lens[i], device=self.device)
             block_indices = slots_range // self.block_size
             slot_offsets = slots_range % self.block_size
             all_block_indices.append(self.block_tables[i][block_indices])
@@ -104,12 +139,8 @@ def update_for_prefill(
         all_slot_offsets = torch.cat(all_slot_offsets)
         self.slots = all_block_indices * self.block_size + all_slot_offsets
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
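
To make the prefill slot arithmetic concrete: with `block_size = 16`, a 20-token prompt whose block table was assigned physical blocks `[7, 2]` occupies slots 112..127 followed by 32..35. A small self-contained check with hypothetical numbers, mirroring the loop above:

    import torch

    block_size, input_len = 16, 20
    block_table = torch.tensor([7, 2])       # physical blocks claimed for this sequence
    positions = torch.arange(input_len)
    block_indices = positions // block_size  # 16 zeros then 4 ones, indexing the block table
    slot_offsets = positions % block_size
    slots = block_table[block_indices] * block_size + slot_offsets
    # slots[:3]  -> tensor([112, 113, 114])  (block 7: 7*16 + 0, 1, 2)
    # slots[16:] -> tensor([32, 33, 34, 35]) (block 2: 2*16 + 0..3)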
@@ -127,7 +158,7 @@ def update_for_decode(
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
+            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
             for i in range(batch_size):
                 if slot_offset_in_block[i] == 0:
                     # need a new block:
@@ -138,12 +169,8 @@ def update_for_decode(
                     self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.key_cache[layer_idx],
-            self.value_cache[layer_idx],
-            self.slots,
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
         )
 
         # Update the number of seen tokens
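
At decode time each sequence appends a single token, so a free block only has to be claimed when the running token count crosses a block boundary (`slot_offset_in_block == 0`); otherwise the write lands at the next offset of an already-owned block. A worked example with hypothetical values:

    import torch

    block_size = 16
    block_table = torch.tensor([7, 2, 5])     # blocks already owned by sequence i
    seen = 33                                 # tokens cached so far for sequence i
    start_block_idx = seen // block_size      # 2 -> third entry of the block table
    slot_offset_in_block = seen % block_size  # 1 -> mid-block, no new allocation
    slot = block_table[start_block_idx] * block_size + slot_offset_in_block
    # slot == 5*16 + 1 == 81; the token's KV is written at flat slot 81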
@@ -193,16 +220,15 @@ def get_max_length(self) -> Optional[int]:
 
     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
+        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
         self.max_seq_len = 0
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.block_tables.device
         origin_table = self.block_tables.clone()
-        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
+        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
         updated_table = torch.zeros_like(beam_idx)