Skip to content

Commit f9866af

Browse files
committed
[GPU] PagedAttention initial impl
1 parent 3ed1868 commit f9866af

File tree

6 files changed

+150
-19
lines changed

6 files changed

+150
-19
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
100100
value_cache_mem, /* value_cache */
101101
instance.input_memory_ptr(7), /* max_context_len */
102102
instance.input_memory_ptr(8), /* context_lens */
103-
instance.input_memory_ptr(9) /* block_tables */ };
103+
instance.input_memory_ptr(9), /* block_tables */
104+
instance.input_memory_ptr(10) /* scale */ };
104105
args.outputs = { instance.output_memory_ptr(0) };
105106
}
106107

@@ -279,20 +280,22 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
279280
static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
280281
auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);
281282

282-
const auto inputs_count = 6;
283+
const auto inputs_count = 7;
283284
const auto query_layout = impl_param.get_input_layout(0);
284285
const auto key_cache_layout = impl_param.get_input_layout(3);
285286
const auto value_cache_layout = impl_param.get_input_layout(4);
286287
const auto max_context_len_layout = impl_param.get_input_layout(7);
287288
const auto context_lens_layout = impl_param.get_input_layout(8);
288289
const auto block_tables_layout = impl_param.get_input_layout(9);
290+
const auto scale_layout = impl_param.get_input_layout(10);
289291

290292
params.inputs.resize(inputs_count);
291293
params.inputs[1] = convert_data_tensor(key_cache_layout);
292294
params.inputs[2] = convert_data_tensor(value_cache_layout);
293295
params.inputs[3] = convert_data_tensor(max_context_len_layout);
294296
params.inputs[4] = convert_data_tensor(context_lens_layout);
295297
params.inputs[5] = convert_data_tensor(block_tables_layout);
298+
params.inputs[6] = convert_data_tensor(scale_layout);
296299

297300
if (query_layout.is_static() && key_cache_layout.is_static() && value_cache_layout.is_static()) {
298301
// query_shape = [batch_size, seq_len, heads_num * head_size]
@@ -328,6 +331,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
328331
{3, in_offsets_map.at(7)},
329332
{4, in_offsets_map.at(8)},
330333
{5, in_offsets_map.at(9)},
334+
{6, in_offsets_map.at(10)},
331335
};
332336
std::map<size_t, size_t> out_tensor_to_offset_map = {
333337
{0, out_offsets_map.at(0)},

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "kv_cache_inst.h"
3030
#include "condition_inst.h"
3131
#include "gather_inst.h"
32+
#include "paged_attention_inst.h"
3233
#include "experimental_detectron_roi_feature_extractor_inst.hpp"
3334
#include "implementation_map.hpp"
3435
#include "graph_optimizer/prepare_buffer_fusing.h"
@@ -553,6 +554,12 @@ event::ptr primitive_inst::realloc_if_needed() {
553554
}
554555
}
555556

557+
// WA: reallocate memory for PA if previous memory is usm_host used from prefill stage inner model
558+
if (_node->is_type<paged_attention>() && _outputs[0] && _outputs[0]->get_allocation_type() != allocation_type::usm_device) {
559+
GPU_DEBUG_TRACE_DETAIL << id() << " reset memory\n";
560+
_max_output_layout_count = 0;
561+
}
562+
556563
// update layout to ensure that it repsects paddings for correct allocation size
557564
if (_node_output_layout.data_padding.get_dynamic_pad_dims() != tensor(0)) {
558565
size_t rank = updated_layout.get_shape().size();

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ KERNEL(pa_kv_cache_update)(
3232
head_elem_idx * KV_CACHE_BLOCK_SIZE +
3333
block_offset;
3434

35-
// if (INPUT0_FEATURE_NUM == 18 && INPUT0_BATCH_NUM == 2) {
36-
// printf("%d. %d - value\n", out_offset, in_offset);
35+
// if (batch_idx == 0) {
36+
// printf("Update value %d. %d (%f)\n", out_offset, in_offset, value_data[in_offset]);
3737
// }
3838

3939
value_cache_data[out_offset] = value_data[in_offset];
@@ -46,9 +46,9 @@ KERNEL(pa_kv_cache_update)(
4646
block_offset * HEAD_SIZE_BLOCKING +
4747
head_size_outer_block * KV_CACHE_BLOCK_SIZE * HEAD_SIZE_BLOCKING +
4848
head_size_inner_block;
49-
// if (INPUT0_FEATURE_NUM == 18 && INPUT0_BATCH_NUM == 2) {
50-
// printf("%d. %d - key\n", out_offset, in_offset);
49+
// if (batch_idx == 0) {
50+
// printf("Update key_cache %d. %d (%f)\n", out_offset, in_offset, key_data[in_offset]);
5151
// }
52-
value_cache_data[out_offset] = key_data[in_offset];
52+
key_cache_data[out_offset] = key_data[in_offset];
5353
#endif
5454
}

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_ref.cl

+128-11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
// constexpr size_t HEAD_SIZE = 64;
1515
// constexpr size_t HEADS_NUM = 32;
1616
// constexpr size_t KV_HEADS_NUM = 4;
17+
// constexpr NUM_QUERIES_PER_KV_HEAD (HEADS_NUM / KV_HEADS_NUM)
1718
// constexpr size_t BLOCK_SIZE = 16;
1819
// constexpr size_t X_SIZE = 4;
1920

@@ -29,14 +30,14 @@
2930
// How much QK outputs each subgroup calculates per cycle
3031
#define QK_PER_SG 4
3132

32-
#define KV_CACHE_BLOCK_STRIDE (HEAD_SIZE * HEADS_NUM * BLOCK_SIZE)
33+
#define KV_CACHE_BLOCK_STRIDE (HEAD_SIZE * KV_HEADS_NUM * BLOCK_SIZE)
3334

3435
#define QUERY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)
3536

3637
#define SUBGROUPS_PER_WG HEAD_SIZE / SUB_GROUP_SIZE
3738

3839
REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
39-
__attribute__((reqd_work_group_size(1, 1, SUB_GROUP_SIZE)))
40+
__attribute__((reqd_work_group_size(1, 1, 64)))
4041
KERNEL(pa_sdpa_ref)(
4142
OPTIONAL_SHAPE_INFO_ARG
4243
__global const INPUT0_TYPE* query,
@@ -45,6 +46,7 @@ KERNEL(pa_sdpa_ref)(
4546
__global const INPUT3_TYPE* max_context_len,
4647
__global const INPUT4_TYPE* context_lens,
4748
__global const INPUT5_TYPE* block_tables,
49+
__global const INPUT6_TYPE* scale,
4850
__global OUTPUT_TYPE* output)
4951
{
5052
const uint seq_idx = get_global_id(0);
@@ -60,6 +62,30 @@ KERNEL(pa_sdpa_ref)(
6062

6163
const uint blocks_num = INPUT5_FEATURE_NUM;
6264

65+
// if (seq_idx < 2 && head_num_idx < 2 && sgid < 2 && sglid < 2) {
66+
// if (INPUT5_FEATURE_NUM == 0) {
67+
// printf("Empty blocks. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n",
68+
// seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
69+
// } else if (INPUT5_FEATURE_NUM == 1) {
70+
// printf("Blocks table[b=0]: %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0],
71+
// seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
72+
// } else if (INPUT5_FEATURE_NUM == 2) {
73+
// printf("Blocks table[b=0]: %d %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0], block_tables[1],
74+
// seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
75+
// } else if (INPUT5_FEATURE_NUM == 3) {
76+
// printf("Blocks table[b=0]: %d %d %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0], block_tables[1], block_tables[2],
77+
// seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
78+
// } else if (INPUT5_FEATURE_NUM == 4) {
79+
// printf("Blocks table[b=0]: %d %d %d %d. Seq_idx=%d, head_num_idx=%d, head_idx=%d, sglid=%d, sgid=%d, batch_idx=%d, token_idx=%d, context_len=%d, scale=%f\n", block_tables[0], block_tables[1], block_tables[2], block_tables[3],
80+
// seq_idx, head_num_idx, head_idx, sglid, sgid, batch_idx, token_idx, context_len, scale[0]);
81+
// }
82+
83+
// if (seq_idx == 0 && head_num_idx == 0 && sgid == 0 && sglid == 0) {
84+
// printf("key_cache[405504]=%f\n", key_cache[405504]);
85+
// printf("value_cache[405504]=%f\n", value_cache[405504]);
86+
// }
87+
// }
88+
6389
// sgid0: 0..3
6490
// sgid1: 4..7
6591
// sgid2: 8..11
@@ -84,7 +110,9 @@ KERNEL(pa_sdpa_ref)(
84110
OUTPUT_TYPE qk[QK_PER_SG] = {0};
85111

86112
for (uint hs = 0; hs < HEAD_ITEMS_PER_WI; hs++) {
87-
const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM + hs * SUB_GROUP_SIZE;
113+
const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM +
114+
head_num_idx * HEAD_SIZE +
115+
hs * SUB_GROUP_SIZE;
88116

89117
// TODO: can be preloaded outside HEAD_ITEMS_PER_WI loop - need to check perf
90118
INPUT0_TYPE q = QUERY_BLOCK_READ(query, query_idx);
@@ -94,34 +122,53 @@ KERNEL(pa_sdpa_ref)(
94122
continue;
95123

96124
const uint key_idx = block_offset +
125+
(head_num_idx / NUM_QUERIES_PER_KV_HEAD) * (HEAD_SIZE / X_SIZE * BLOCK_SIZE * X_SIZE) +
97126
(X_SIZE * QK_PER_SG) * sgid +
98127
(HEAD_ITEMS_PER_WI * BLOCK_SIZE * X_SIZE) * hs +
99128
(sglid / X_SIZE) * X_SIZE * BLOCK_SIZE +
100129
(sglid % X_SIZE) + qk_idx * X_SIZE;
130+
101131
// TODO1: try block loading and shuffling
102132
// TODO2: try to load k*4 times and then calculate
103133
// TODO3: try bigger X block
104134
INPUT1_TYPE k = key_cache[key_idx];
105135

136+
137+
// if (seq_idx == 0 && head_num_idx == 0) {
138+
// printf("main_calc: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d, block=%d, hs=%d, qk_idx=%d, current_token=%d, query_idx=%d, key_idx=%d (block_offset=%d): %f * %f\n",
139+
// seq_idx, head_num_idx, sgid, sglid, block, hs, qk_idx, current_token, query_idx, key_idx - block_offset, block_offset, q, k);
140+
// }
141+
106142
qk[qk_idx] = mad(q, k, qk[qk_idx]);
107143
}
108144
}
109145

110-
// Summurize qk calculation across all WIs
146+
// Summurize qk calculation across all WIs and apply scale
111147
for (uint qk_idx = 0; qk_idx < QK_PER_SG; qk_idx++) {
112-
qk[QK_PER_SG] = sub_group_reduce_add(qk[QK_PER_SG]);
113-
qk_max = OUTPUT_MAX_FUNC(qk_max, qk[QK_PER_SG]);
148+
const uint current_token = block * BLOCK_SIZE + sgid * QK_PER_SG + qk_idx;
149+
if (current_token < context_len) {
150+
OUTPUT_TYPE tmp_print = qk[qk_idx];
151+
qk[qk_idx] = sub_group_reduce_add(qk[qk_idx]);
152+
// if (head_num_idx < 4)
153+
// printf("final_calc: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: before qk[%d]=%f, after=%f\n",
154+
// seq_idx, head_num_idx, sgid, sglid, qk_idx, tmp_print, qk[qk_idx]);
155+
qk[qk_idx] = scale[0] * qk[qk_idx];
156+
qk_max = OUTPUT_MAX_FUNC(qk_max, qk[qk_idx]);
157+
}
114158
}
115159

116160
// Save QK results to local memory
117161
if (sglid < QK_PER_SG) {
118-
const uint qk_local_idx = block * BLOCK_SIZE * sgid * QK_PER_SG + sglid;
119-
qk_vals[qk_local_idx] = qk[sglid];
162+
const uint current_token = block * BLOCK_SIZE + sgid * QK_PER_SG + sglid;
163+
// Fixed -> // const uint qk_local_idx = block * BLOCK_SIZE * sgid * QK_PER_SG + sglid;
164+
// OUTPUT_TYPE tmp_print = (current_token >= context_len ? 0 : qk[sglid]);
165+
// if (head_num_idx < 4 || head_num_idx == 31)
166+
// printf("slm save: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: qk_vals[%d]=%f. Max=%f\n",
167+
// seq_idx, head_num_idx, sgid, sglid, current_token, tmp_print, qk_max);
168+
qk_vals[current_token] = current_token >= context_len ? 0 : qk[sglid];
120169
}
121170
}
122171

123-
/* WARNING NEED TO ADD BIAS BEFORE SOFTMAX */
124-
125172
// Apply SoftMax operation
126173
__local OUTPUT_TYPE qk_max_vals[SUBGROUPS_PER_WG];
127174
__local OUTPUT_TYPE qk_sum_vals[SUBGROUPS_PER_WG];
@@ -138,23 +185,35 @@ KERNEL(pa_sdpa_ref)(
138185
// Final max value after reduction across of all SG and WI
139186
qk_max = sub_group_reduce_max(qk_max);
140187

188+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
189+
// printf("QK max value = %f\n", qk_max);
190+
// }
191+
141192
OUTPUT_TYPE exp_sum = OUTPUT_VAL_ZERO;
142193
for (uint qk_idx = 0; qk_idx < CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
143194
const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
144195
if (data_idx < context_len) {
145196
OUTPUT_TYPE val = native_exp(qk_vals[data_idx] - qk_max);
146197
exp_sum += val;
147198
qk_vals[data_idx] = val;
199+
// if (head_num_idx < 4 || head_num_idx == 31)
200+
// printf("head_num %d, sgid = %d, sglid = %d, exp_sum = %f\n", head_num_idx, sgid, sglid, exp_sum);
148201
}
149202
}
150203

151204
exp_sum = sub_group_reduce_add(exp_sum);
152205

206+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
207+
// printf("exp_sum final value = %f\n", exp_sum);
208+
// }
209+
153210
if (sglid == 0)
154211
qk_sum_vals[sgid] = exp_sum;
155212

156213
barrier(CLK_LOCAL_MEM_FENCE);
157214

215+
exp_sum = OUTPUT_VAL_ZERO;
216+
158217
if (sglid < SUBGROUPS_PER_WG)
159218
exp_sum = qk_sum_vals[sglid];
160219

@@ -163,6 +222,8 @@ KERNEL(pa_sdpa_ref)(
163222

164223
const OUTPUT_TYPE inv_sum = OUTPUT_VAL_ONE / exp_sum;
165224

225+
226+
// TODO: replace CEIL_DIV with ALIGN and use += SUBGROUPS_PER_WG * SUB_GROUP_SIZE increment
166227
for (uint qk_idx = 0; qk_idx < CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
167228
const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
168229
if (data_idx < context_len) {
@@ -174,5 +235,61 @@ KERNEL(pa_sdpa_ref)(
174235
barrier(CLK_LOCAL_MEM_FENCE);
175236
}
176237

177-
output[seq_idx + sglid] = qk_vals[sglid % context_len];
238+
// if (seq_idx == 0 && sgid == 0 && sglid == 0) {
239+
// for (uint i = 0; i < context_len; i++) {
240+
// printf("Softmax res for %d head: %d. %f\n", head_num_idx, i, qk_vals[i]);
241+
// }
242+
// }
243+
244+
{
245+
OUTPUT_TYPE acc = OUTPUT_VAL_ZERO;
246+
247+
for (uint qk_idx = 0; qk_idx < ALIGN(context_len, SUB_GROUP_SIZE); qk_idx += SUB_GROUP_SIZE) {
248+
const uint qk_offset = qk_idx + sglid;
249+
250+
OUTPUT_TYPE qk = qk_offset < context_len ? qk_vals[qk_offset] : OUTPUT_VAL_ZERO;
251+
252+
const uint block_idx = block_tables[batch_idx * blocks_num + (qk_idx / SUB_GROUP_SIZE)];
253+
if (block_idx == 0)
254+
continue;
255+
256+
const uint value_cache_offset = block_idx * KV_CACHE_BLOCK_STRIDE +
257+
(head_num_idx / NUM_QUERIES_PER_KV_HEAD) * (HEAD_SIZE * BLOCK_SIZE) +
258+
sgid * (SUB_GROUP_SIZE * BLOCK_SIZE) +
259+
sglid * BLOCK_SIZE;
260+
261+
#define VALUE_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, BLOCK_SIZE)
262+
#define VALUE_VLOAD(offset, ptr) CAT(vload, BLOCK_SIZE)(offset, ptr)
263+
264+
ushort16 v_tmp = vload16(0, (__global ushort*)(value_cache + value_cache_offset));
265+
OUTPUT_TYPE* v = (OUTPUT_TYPE*)&v_tmp;
266+
267+
// VALUE_VEC_TYPE* tmp_print = v;
268+
269+
// if (seq_idx == 0 && head_num_idx == 0) {
270+
// printf("gemm2: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d, block_idx=%d, qk_idx=%d, qk_offset=%d, value_offset=%d (block_offset=%d): %v8f\n",
271+
// seq_idx, head_num_idx, sgid, sglid, block_idx, qk_idx, qk_offset, value_cache_offset - (block_idx * KV_CACHE_BLOCK_STRIDE), block_idx * KV_CACHE_BLOCK_STRIDE, *tmp_print);
272+
// }
273+
274+
for (uint token = 0; token < BLOCK_SIZE; token++) {
275+
OUTPUT_TYPE qk_tmp = sub_group_broadcast(qk, token);
276+
if (qk_idx + token < context_len) {
277+
acc = mad(qk_tmp, v[token], acc);
278+
}
279+
}
280+
}
281+
282+
283+
const uint output_offset = seq_idx * (HEADS_NUM * HEAD_SIZE) +
284+
head_num_idx * HEAD_SIZE +
285+
sgid * SUB_GROUP_SIZE +
286+
sglid;
287+
288+
// if (seq_idx == 0 && head_num_idx < 2 || head_num_idx == 31) {
289+
// printf("output res: seq_idx=%d, head_num_idx=%d, sgid=%d, sglid=%d: output[%d] = %f\n",
290+
// seq_idx, head_num_idx, sgid, sglid, output_offset, acc);
291+
// }
292+
293+
output[output_offset] = acc;
294+
}
178295
}

src/plugins/intel_gpu/src/kernel_selector/kernels/paged_attention/kv_cache_update_kernel_ref.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ bool KVCacheUpdateKernelRef::Validate(const Params& params) const {
117117
JitConstants KVCacheUpdateKernelRef::GetJitConstants(const kv_cache_update_params& kernel_params, KernelMode mode) const {
118118
JitConstants jit = MakeBaseParamsJitConstants(kernel_params);
119119

120+
GPU_DEBUG_TRACE << "Configure kernel for " << static_cast<int>(mode) << "\n";
121+
120122
if (mode == KernelMode::key_cache_update)
121123
jit.AddConstant(MakeJitConstant("KEY_CACHE_UPDATE", 1));
122124
else

src/plugins/intel_gpu/src/kernel_selector/kernels/paged_attention/sdpa_kernel_ref.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ JitConstants SDPAKernelRef::GetJitConstants(const sdpa_params& kernel_params) co
119119
jit.AddConstant(MakeJitConstant("HEAD_SIZE", HEAD_SIZE));
120120
jit.AddConstant(MakeJitConstant("HEADS_NUM", HEADS_NUM));
121121
jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", KV_HEADS_NUM));
122+
jit.AddConstant(MakeJitConstant("NUM_QUERIES_PER_KV_HEAD", HEADS_NUM / KV_HEADS_NUM));
122123
jit.AddConstant(MakeJitConstant("BLOCK_SIZE", BLOCK_SIZE));
123124
jit.AddConstant(MakeJitConstant("X_SIZE", X_SIZE));
124125

@@ -140,7 +141,7 @@ CommonDispatchData SDPAKernelRef::SetDefault(const sdpa_params& kernel_params) {
140141
dispatch_data.gws = { tokens_num,
141142
kernel_params.configuration.heads_num,
142143
kernel_params.configuration.head_size };
143-
dispatch_data.lws = { 1, 1, 16 };
144+
dispatch_data.lws = { 1, 1, kernel_params.configuration.head_size };
144145
}
145146

146147
return dispatch_data;

0 commit comments

Comments
 (0)