Skip to content

Commit 368923f

Browse files
committed
Fix accuracy issue, remove is_prompt check
1 parent c19e344 commit 368923f

File tree

4 files changed

+151
-99
lines changed

4 files changed

+151
-99
lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/multi_stage_primitive.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
5353
}
5454
this->can_reuse_memory = false;
5555
this->_kernel_name = other._kernel_name;
56+
this->can_reuse_memory = other.can_reuse_memory;
5657
this->_is_dynamic = other._is_dynamic;
5758
}
5859

src/plugins/intel_gpu/src/graph/paged_attention.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
4242

4343
void paged_attention_inst::update_shape_info_tensor(const kernel_impl_params& params) {
4444
auto& service_stream = this->get_network().get_engine().get_service_stream();
45-
auto is_prefill_memory = this->input_memory_ptr(5);
46-
mem_lock<uint8_t, mem_lock_type::read> is_prefill_memory_lock(is_prefill_memory, service_stream);
47-
bool is_prefill_stage = is_prefill_memory_lock[0];
48-
is_prefill_stage = false;
45+
// auto is_prefill_memory = this->input_memory_ptr(5);
46+
// mem_lock<uint8_t, mem_lock_type::read> is_prefill_memory_lock(is_prefill_memory, service_stream);
47+
// bool is_prefill_stage = is_prefill_memory_lock[0];
48+
bool is_prefill_stage = false;
4949
if (!is_prefill_stage) {
5050
parent::update_shape_info_tensor(params);
5151
} else {

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_ref.cl

+34-19
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,9 @@ KERNEL(pa_sdpa_ref)(
5858
__global OUTPUT_TYPE* tmp_out,
5959
const uint num_of_portions
6060
#else
61-
__global OUTPUT_TYPE* output,
61+
__global OUTPUT_TYPE* output
6262
#endif
63-
)
64-
{
63+
) {
6564
const uint seq_idx = get_global_id(0);
6665
const uint head_num_idx = get_global_id(1);
6766
const uint head_idx = get_global_id(2);
@@ -73,7 +72,7 @@ KERNEL(pa_sdpa_ref)(
7372

7473
const uint context_len = context_lens[batch_idx];
7574

76-
const uint total_blocks_num = INPUT5_FEATURE_NUM;
75+
const uint blocks_pitch = INPUT5_FEATURE_NUM;
7776

7877
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
7978
const uint portion_id = get_group_id(2);
@@ -86,6 +85,8 @@ KERNEL(pa_sdpa_ref)(
8685
const uint block_start_idx = 0;
8786
#endif
8887

88+
const uint total_blocks_num = CEIL_DIV(context_len, BLOCK_SIZE);
89+
8990
// if (seq_idx < 2 && head_num_idx < 2 && sgid < 2 && sglid < 2) {
9091
// if (INPUT5_BATCH_NUM == 2) {
9192
// if (INPUT5_FEATURE_NUM == 0) {
@@ -159,12 +160,13 @@ KERNEL(pa_sdpa_ref)(
159160

160161
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
161162
// FINAL: Compile time restriction: devisible SEQ_LEN_PORTION_SIZE / BLOCK_SIZE
162-
const uint blocks_num = SEQ_LEN_PORTION_SIZE / BLOCK_SIZE;
163+
const uint blocks_num = (portion_id == num_of_portions - 1) ? (total_blocks_num - (portion_id * SEQ_LEN_PORTION_SIZE / BLOCK_SIZE))
164+
: (SEQ_LEN_PORTION_SIZE / BLOCK_SIZE);
163165
#else
164166
const uint blocks_num = total_blocks_num;
165167
#endif
166168
for (uint block_num = 0; block_num < blocks_num; block_num++) {
167-
const uint block_idx = batch_idx * total_blocks_num + block_start_idx + block_num;
169+
const uint block_idx = batch_idx * blocks_pitch + block_start_idx + block_num;
168170
const uint block_offset = block_tables[block_idx] * KV_CACHE_BLOCK_STRIDE;
169171

170172
OUTPUT_TYPE qk[QK_VALS_PER_SG_PER_ITER] = {0};
@@ -263,8 +265,9 @@ KERNEL(pa_sdpa_ref)(
263265
ulong timer_end = intel_get_cycle_counter();
264266
ulong total_time = timer_end - timer_start;
265267

266-
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
267-
// printf("SDPA kernel GEMM1: %d; qk_max=%f\n", (uint)total_time, qk_max);
268+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_local_id(2) == 0 && context_len >= 496)
269+
// printf("%d. %d. SDPA kernel GEMM1: %d; qk_max=%f, blocks_num=%d, total_blocks_num=%d, portion_id=%d, num_of_portions=%d\n",
270+
// context_len, get_global_id(2), (uint)total_time, qk_max, blocks_num, total_blocks_num, portion_id, num_of_portions);
268271
}
269272

270273
// barrier(CLK_LOCAL_MEM_FENCE);
@@ -311,10 +314,12 @@ KERNEL(pa_sdpa_ref)(
311314

312315
// // temp test
313316
// barrier(CLK_LOCAL_MEM_FENCE);
317+
ulong timer_start2 = intel_get_cycle_counter();
314318

315319
ACCUMULATOR_TYPE exp_sum = ACCUMULATOR_VAL_ZERO;
316320
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
317-
const uint qk_num = CEIL_DIV(SEQ_LEN_PORTION_SIZE, SUBGROUPS_PER_WG * SUB_GROUP_SIZE);
321+
const uint qk_num = (num_of_portions == 1) ? CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE)
322+
: CEIL_DIV(SEQ_LEN_PORTION_SIZE, SUBGROUPS_PER_WG * SUB_GROUP_SIZE);
318323
#else
319324
const uint qk_num = CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE);
320325
#endif
@@ -338,6 +343,7 @@ KERNEL(pa_sdpa_ref)(
338343
}
339344
}
340345

346+
ulong timer_start3 = intel_get_cycle_counter();
341347

342348
// // temp test
343349
// barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,6 +371,7 @@ KERNEL(pa_sdpa_ref)(
365371

366372
exp_sum = ACCUMULATOR_VAL_ZERO;
367373

374+
ulong timer_start4 = intel_get_cycle_counter();
368375

369376
// FINAL FIX: Compile time restiction SUBGROUPS_PER_WG <= SG_SIZE
370377
if (sglid < SUBGROUPS_PER_WG)
@@ -391,9 +398,7 @@ KERNEL(pa_sdpa_ref)(
391398
}
392399

393400
barrier(CLK_LOCAL_MEM_FENCE);
394-
395-
ulong timer_end = intel_get_cycle_counter();
396-
ulong total_time = timer_end - timer_start;
401+
ulong timer_start5 = intel_get_cycle_counter();
397402

398403
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
399404
{
@@ -417,8 +422,17 @@ KERNEL(pa_sdpa_ref)(
417422
}
418423
#endif
419424

420-
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
421-
// printf("SDPA kernel Softmax: %d\n", (uint)total_time);
425+
ulong timer_end = intel_get_cycle_counter();
426+
427+
ulong total_time1 = timer_start2 - timer_start;
428+
ulong total_time2 = timer_start3 - timer_start2;
429+
ulong total_time3 = timer_start4 - timer_start3;
430+
ulong total_time4 = timer_start5 - timer_start4;
431+
ulong total_time5 = timer_end - timer_start5;
432+
433+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_local_id(2) == 0)
434+
// printf("%d. SDPA kernel Softmax: qk_max calc: %d, exp_sum_loc calc: %d, exp_sum calc: %d, qk_vals recalc: %d, save: %d\n",
435+
// get_global_id(2), (uint)total_time1, (uint)total_time2, (uint)total_time3, (uint)total_time4, (uint)total_time5);
422436
}
423437

424438
// if (seq_idx == 0 && sgid == 0 && sglid == 0) {
@@ -433,17 +447,18 @@ KERNEL(pa_sdpa_ref)(
433447

434448
#ifdef USE_SPLIT_ACROSS_SEQ_LEN
435449
// FINAL: Compile time restriction: devisible SEQ_LEN_PORTION_SIZE / BLOCK_SIZE
436-
const uint qk_num = SEQ_LEN_PORTION_SIZE / BLOCK_SIZE * SUB_GROUP_SIZE;
450+
const uint qk_num = (portion_id == num_of_portions - 1) ? (context_len - (portion_id * SEQ_LEN_PORTION_SIZE))
451+
: (SEQ_LEN_PORTION_SIZE);
437452
#else
438-
const uint qk_num = ALIGN(context_len, SUB_GROUP_SIZE);
453+
const uint qk_num = context_len;
439454
#endif
440455
for (uint qk_idx = 0; qk_idx < qk_num; qk_idx += SUB_GROUP_SIZE) {
441456
const uint qk_offset_local = qk_idx + sglid;
442457
const uint qk_offset_global = block_start_idx * BLOCK_SIZE + qk_offset_local;
443458

444459
OUTPUT_TYPE qk = qk_offset_global < context_len ? qk_vals[qk_offset_local] : OUTPUT_VAL_ZERO;
445460

446-
const uint block_idx = block_tables[batch_idx * total_blocks_num + block_start_idx + (qk_idx / BLOCK_SIZE)];
461+
const uint block_idx = block_tables[batch_idx * blocks_pitch + block_start_idx + (qk_idx / BLOCK_SIZE)];
447462
// if (block_idx == 0)
448463
// continue;
449464

@@ -504,8 +519,8 @@ KERNEL(pa_sdpa_ref)(
504519
ulong timer_end = intel_get_cycle_counter();
505520
ulong total_time = timer_end - timer_start;
506521

507-
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0)
508-
// printf("SDPA kernel GEMM2: %d\n", (uint)total_time);
522+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_local_id(2) == 0)
523+
// printf("%d. SDPA kernel GEMM2: %d\n", get_global_id(2), (uint)total_time);
509524
}
510525
}
511526

0 commit comments

Comments
 (0)