@@ -73,10 +73,9 @@ def get_kv_cache_shape(
         num_blocks: int,
         block_size: int,
         num_kv_heads: int,
-        kv_lora_rank: int,
+        head_size: int,
     ) -> Tuple[int, ...]:
-        k_pe_size = kv_lora_rank // 8
-        return (num_blocks, block_size, kv_lora_rank + k_pe_size), True
+        return (num_blocks, block_size, head_size), (num_blocks, block_size, head_size // 9 * 8)

     @staticmethod
     def get_impl_cls() -> Type["HPUAttentionImpl"]:
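For context on the new return value, a minimal sketch of the shape arithmetic, with assumed DeepSeek-style MLA sizes that are illustrative only and not taken from this change: head_size is the latent width plus the rope width, and since the rope width is one eighth of the latent width (the old k_pe_size formula), head_size // 9 * 8 recovers the latent-only width used for the second (value) cache.

# Illustrative sketch only; assumed MLA sizes, not part of this diff.
kv_lora_rank = 512                            # assumed latent ("compressed KV") width
qk_rope_head_dim = kv_lora_rank // 8          # 64, mirrors the old k_pe_size formula
head_size = kv_lora_rank + qk_rope_head_dim   # 576, what callers now pass in

num_blocks, block_size = 128, 128             # arbitrary example values

k_shape = (num_blocks, block_size, head_size)            # latent + rope ("key") cache
v_shape = (num_blocks, block_size, head_size // 9 * 8)   # latent-only ("value") cache

assert head_size // 9 * 8 == kv_lora_rank                # 576 // 9 * 8 == 512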
@@ -137,7 +136,8 @@ def __init__(
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        self.latent_cache = VLLMKVCache()
+        self.latent_cache_k = VLLMKVCache()
+        self.latent_cache_v = VLLMKVCache()
         HPUFusedSDPA = kernels.fsdpa()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
@@ -186,9 +186,6 @@ def forward(
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
             input_positions = attn_metadata.input_positions.view(-1)
-            print("q_pe", q_pe.shape)
-            print("k_pe", k_pe.shape)
-            print("input_positions", attn_metadata.input_positions.shape)
             q_pe, k_pe = \
                 self.rotary_emb(input_positions, q_pe, k_pe)
         else:
@@ -197,9 +194,6 @@ def forward(

             q_pe = q[..., self.qk_nope_head_dim:]

-            # print("q_pe shape", q_pe.shape)
-            # print("k_pe shape", k_pe.shape)
-            # print("input_positions shape", attn_metadata.input_positions.shape)
             input_positions = attn_metadata.input_positions.view(-1)
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
@@ -208,15 +202,29 @@ def forward(
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets

-        latent_vec = torch.concat(
+        latent_vec_k = torch.concat(
             (k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)), dim=-1)
         # assert layer._k_scale == 0, f"got _k_scale={layer._k_scale}"
-        # print(f"layer._k_scale={layer._k_scale}")
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+        latent_vec_v = k_c_normed.view(-1, self.kv_lora_rank)
+        if is_prefill:
+            latent_vec_k = latent_vec_k.unflatten(0, (block_indices.size(0), -1))
+            latent_vec_v = latent_vec_v.unflatten(0, (block_indices.size(0), -1))
+        # print("latent_vec", latent_vec.shape)
+

         # write the latent and rope to kv cache
-        if kv_cache is not None:
-            kv_cache = self.latent_cache(latent_vec, kv_cache, block_indices,
+        if kv_cache is not None and len(kv_cache) == 2:
+            # print(f"k cache shape: {kv_cache[0].shape}")
+            # print(f"v cache shape: {kv_cache[1].shape}")
+            # print(f"latent vec k shape: {latent_vec_k.shape}")
+            # print(f"latent vec v shape: {latent_vec_v.shape}")
+
+            k_cache = self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
                                           block_offsets)
+            v_cache = self.latent_cache_v(latent_vec_v, kv_cache[1], block_indices,
+                                          block_offsets)
+            kv_cache = (k_cache, v_cache)

         if is_prefill:
             return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata, batch_size)
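As a hedged illustration of the cache-write step in this hunk, with assumed example sizes and the prefill unflatten step omitted for brevity: latent_vec_k carries the latent plus rope features destined for kv_cache[0], while latent_vec_v keeps only the latent part for kv_cache[1].

import torch

# Hedged sketch of the cache-write shapes; sizes are assumed examples.
num_tokens = 16
kv_lora_rank, qk_rope_head_dim = 512, 64

k_c_normed = torch.randn(num_tokens, kv_lora_rank)       # latent part
k_pe = torch.randn(num_tokens, qk_rope_head_dim)         # rope part

latent_vec_k = torch.cat((k_c_normed, k_pe), dim=-1)     # (16, 576): written to kv_cache[0]
latent_vec_v = k_c_normed                                 # (16, 512): written to kv_cache[1]

assert latent_vec_k.shape[-1] == kv_lora_rank + qk_rope_head_dim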
@@ -268,20 +276,14 @@ def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
         batch_size: int
     ) -> torch.Tensor:
-        print(f"q_nope shape: {q_nope.shape}")
-        print(f"q_pe shape: {q_pe.shape}")
-
         q = torch.cat([q_nope, q_pe], dim=-1)
-        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
-        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
+        kv_c_and_k_pe_cache = kv_cache[0].unsqueeze(2)
+        kv_c_cache = kv_cache[1].unsqueeze(2)

-        print(f"q shape: {q.shape}")
-        print(f"kv_c_and_k_pe_cache shape: {kv_c_and_k_pe_cache.shape}")
-        print(f"kv_c_cache shape: {kv_c_cache.shape}")
         output = HPUPagedAttention.forward_decode(
             query=q,
             key_cache=kv_c_and_k_pe_cache,
@@ -296,13 +298,11 @@ def _forward_decode(
             matmul_av_op=self.matmul_av,
             batch2block_matmul_op=self.batch2block_matmul,
             block2batch_matmul_op=self.block2batch_matmul,
-            keys_fetch_func=self.latent_cache.fetch_from_cache,
-            values_fetch_func=self.latent_cache.fetch_from_cache)
+            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+            values_fetch_func=self.latent_cache_v.fetch_from_cache)
         output = output.view(batch_size, 1, -1)
-        print("output", output.shape)
         result = self._v_up_proj_and_o_proj(output)
         result = result.view(batch_size, 1, -1)
-        print("result", result.shape)
         return result

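To make the decode-side consumption of the split cache concrete, a small sketch with assumed sizes (illustrative only): the first cache, holding latent plus rope, acts as the key cache; the second, latent-only cache acts as the value cache; both gain a singleton head dimension before being handed to the paged-attention kernel.

import torch

# Assumed example sizes; mirrors only the reshape step of _forward_decode.
num_blocks, block_size = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64

k_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + qk_rope_head_dim)
v_cache = torch.zeros(num_blocks, block_size, kv_lora_rank)
kv_cache = (k_cache, v_cache)

kv_c_and_k_pe_cache = kv_cache[0].unsqueeze(2)   # (4, 16, 1, 576): key cache
kv_c_cache = kv_cache[1].unsqueeze(2)            # (4, 16, 1, 512): value cache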