-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

 from intel_extension_for_pytorch.llm.modules import PagedAttention
@@ -95,14 +95,87 @@ def __init__(
             for _ in range(self.num_hidden_layers)
         ]

+    def update_for_prefill(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+        length_list: Optional[List],
+    ):
+        all_block_indices = []
+        all_slot_offsets = []
+        for i in range(batch_size):
+            num_blocks = (length_list[i] + self.block_size - 1) // self.block_size
+            for b_idx in range(num_blocks):
+                if self.block_tables[i][b_idx] == -1:
+                    # need a free block
+                    self.block_tables[i][b_idx] = self.free_blocks.pop(0)
+
+            slots_range = torch.arange(length_list[i], device=key_states.device)
+            block_indices = slots_range // self.block_size
+            slot_offsets = slots_range % self.block_size
+            all_block_indices.append(self.block_tables[i][block_indices])
+            all_slot_offsets.append(slot_offsets)
+
+        all_block_indices = torch.cat(all_block_indices)
+        all_slot_offsets = torch.cat(all_slot_offsets)
+        slots_tensor = all_block_indices * self.block_size + all_slot_offsets
+        # Update the cache
+        PagedAttention.reshape_and_cache(
+            key_states,
+            value_states,
+            self.kv_cache[layer_idx][0],
+            self.kv_cache[layer_idx][1],
+            slots_tensor,
+        )
+
+        # Update the number of seen tokens
+        if layer_idx == self.num_hidden_layers - 1:
+            for i in range(batch_size):
+                self._seen_tokens[i] += length_list[i]
+
+    def update_for_decode(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+    ):
+        slots = []
+        for i in range(batch_size):
+            start_block_idx = self._seen_tokens[i] // self.block_size
+            num_blocks = (self._seen_tokens[i] + self.block_size) // self.block_size
+            for b_idx in range(start_block_idx, num_blocks):
+                if self.block_tables[i][b_idx] == -1:
+                    # need a free block
+                    self.block_tables[i][b_idx] = self.free_blocks.pop(0)
+            block_idx = self._seen_tokens[i] // self.block_size
+            slot_offset_in_block = self._seen_tokens[i] % self.block_size
+            slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)
+
+        # Update the cache
+        PagedAttention.reshape_and_cache(
+            key_states,
+            value_states,
+            self.kv_cache[layer_idx][0],
+            self.kv_cache[layer_idx][1],
+            torch.tensor(slots, device=key_states.device),
+        )
+
+        # Update the number of seen tokens
+        if layer_idx == self.num_hidden_layers - 1:
+            for i in range(batch_size):
+                self._seen_tokens[i] += 1
+
     def update(
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
         layer_idx: int,
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        input_lens: torch.Tensor,
+        length_list: Optional[List],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -117,39 +190,14 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
+
         batch_size = position_ids.shape[0]
-        slots = []
         if self.get_seq_length() == 0:
             # prefill
-            num_slots = input_lens.tolist()
+            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, length_list)
         else:
             # decode
-            num_slots = [1] * batch_size
-            for i in range(batch_size):
-                start_block_idx = self._seen_tokens[i] // self.block_size
-                num_blocks = (self._seen_tokens[i] + num_slots[i] + self.block_size - 1) // self.block_size
-                for b_idx in range(start_block_idx, num_blocks):
-                    if self.block_tables[i][b_idx] == -1:
-                        # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks.pop(0)
-                for slot in range(num_slots[i]):
-                    block_idx = (self._seen_tokens[i] + slot) // self.block_size
-                    slot_offset_in_block = (self._seen_tokens[i] + slot) % self.block_size
-                    slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)
-
-        # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.kv_cache[layer_idx][0],
-            self.kv_cache[layer_idx][1],
-            torch.tensor(slots, device=key_states.device),
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            for i in range(batch_size):
-                self._seen_tokens[i] += num_slots[i]
+            self.update_for_decode(key_states, value_states, layer_idx, batch_size)

         return self.kv_cache[layer_idx][0], self.kv_cache[layer_idx][1]
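Both new methods reduce to the same mapping: a token's logical position in its sequence is split into (logical block, offset), and the per-sequence block table translates the logical block to a physical one, giving physical slot = physical_block * block_size + offset. The slot arithmetic can be sanity-checked in isolation; below is a minimal standalone sketch, independent of IPEX and the cache class. The `BLOCK_SIZE` value, the toy block table, and the helper names are illustrative assumptions, not part of this patch:

```python
import torch

BLOCK_SIZE = 4  # assumed toy value; the real cache uses self.block_size


def prefill_slots(block_table: torch.Tensor, seq_len: int) -> torch.Tensor:
    # One physical slot per prompt token, computed in a single vectorized pass
    positions = torch.arange(seq_len)
    block_indices = positions // BLOCK_SIZE  # logical block of each token
    slot_offsets = positions % BLOCK_SIZE    # offset of each token inside its block
    return block_table[block_indices] * BLOCK_SIZE + slot_offsets


def decode_slot(block_table: torch.Tensor, seen_tokens: int) -> int:
    # A single new token lands right after the last cached one
    block_idx = seen_tokens // BLOCK_SIZE
    offset = seen_tokens % BLOCK_SIZE
    return block_table[block_idx].item() * BLOCK_SIZE + offset


# Toy block table for one sequence: logical block 0 -> physical block 7,
# logical block 1 -> physical block 2; -1 marks unallocated blocks
table = torch.tensor([7, 2, -1, -1])

print(prefill_slots(table, 6))  # tensor([28, 29, 30, 31,  8,  9])
print(decode_slot(table, 6))    # 10: token 6 is in logical block 1, offset 2
```

This also shows why the refactor pays off: `update_for_prefill` can vectorize the whole mapping with `torch.arange` over each prompt, while `update_for_decode` only needs one scalar slot per sequence, so the old per-slot inner loop disappears from the decode path.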