Skip to content

Commit f02a27f

Browse files
committed
Fix micro_sdpa kernel arguments
1 parent e64244f commit f02a27f

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

+11 -4
Original file line number | Diff line number | Diff line change
@@ -152,12 +152,12 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
152152
#if WITH_SCALE
153153
global SCALE_DATA_T *scale_ptr,
154154
#endif
155-
int d,
156155
#if IS_PAGED_ATTENTION
157156
const __global int* blocked_indexes_start,
158157
const __global int* blocked_indexes_end,
159158
const __global int* gws_seq_indexes_correspondence
160159
#else
160+
int d,
161161
int k,
162162
int q
163163
#endif
@@ -172,11 +172,18 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
172172
const uint q_tile_idx = get_group_id(0);
173173
const uint block_start_pos = blocked_indexes_start[q_tile_idx];
174174
const uint block_end_pos = blocked_indexes_end[q_tile_idx];
175-
const uint subsequence_q_tile_idx = block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx]];
175+
const uint gws_mapping = gws_seq_indexes_correspondence[q_tile_idx];
176+
const uint subsequence_begin = subsequence_begins[gws_mapping];
177+
const uint subsequence_begin_next = subsequence_begins[gws_mapping + 1];
178+
const uint subsequence_q_tile_idx = block_start_pos - subsequence_begin;
176179
// const uint sequence_idx_end = block_end_pos - block_start_pos;
177-
const uint subsequence_begin = subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx]];
178-
const int k = subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx] + 1] - subsequence_begins[gws_seq_indexes_correspondence[q_tile_idx]];
180+
const int k = subsequence_begins[gws_mapping + 1] - subsequence_begin;
179181
const int q = k;
182+
const int d = HEAD_SIZE;
183+
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
184+
printf("q_tile_idx=%d block_start_pos=%d block_end_pos=%d gws_mapping=%d subsequence_begin=%d subsequence_begin_next=%d subsequence_q_tile_idx=%d k=%d d=%d\n",
185+
q_tile_idx, block_start_pos, block_end_pos, gws_mapping, subsequence_begin, subsequence_begin_next, subsequence_q_tile_idx, k, d);
186+
}
180187
#endif
181188
uint sg_ij = sub_group_broadcast(get_local_id(1), 0);
182189
uint b0 = get_group_id(1);

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp

+14 -1
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,9 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
508508
print_arr(q_dims, q_dims.size(), "q_dims");
509509
print_arr(k_dims, k_dims.size(), "k_dims");
510510
print_arr(v_dims, v_dims.size(), "v_dims");
511-
std::cout << Q_num_heads_dim.is_dynamic << " " << K_num_heads_dim.is_dynamic << " " << V_num_heads_dim.is_dynamic << " " << K_num_heads_dim.v << " " << V_num_heads_dim.v << "\n";
511+
std::cout << Q_num_heads_dim.is_dynamic << " "
512+
<< K_num_heads_dim.is_dynamic << " "
513+
<< V_num_heads_dim.is_dynamic << " " << K_num_heads_dim.v << " " << V_num_heads_dim.v << "\n";
512514
return false;
513515
}
514516

@@ -555,6 +557,7 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
555557
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size(prim_params.engineInfo.arch)));
556558
jit.AddConstant(MakeJitConstant("INVERT_SCALE", false));
557559
jit.AddConstant(MakeJitConstant("SCALE_DATA_T", "half"));
560+
jit.AddConstant(MakeJitConstant("HEAD_SIZE", head_size));
558561

559562
jit.AddConstant(MakeJitConstant("WITH_ATTN_MASK", sdpa_inputs > 3));
560563
jit.AddConstant(MakeJitConstant("WITH_SCALE", sdpa_inputs > 4));
@@ -733,6 +736,8 @@ CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params, const
733736
auto seq_length = get_seq_length(params.inputs[0], params.input0_order).v;
734737
if (params.conf.is_paged_attention) {
735738
seq_length = params.conf.paged_attention_aligned_seq_len;
739+
GPU_DEBUG_TRACE_DETAIL << "seq_len=" << seq_length << "\n";
740+
GPU_DEBUG_TRACE_DETAIL << "wg_tile_q=" << wg_tile_q << "\n";
736741
}
737742

738743
dispatch_data.gws[0] *= CeilDiv(seq_length, wg_tile_q);
@@ -767,6 +772,14 @@ clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, bool is
767772
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 2}); // V
768773
kernel.params.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // A
769774

775+
776+
if (params.conf.is_paged_attention) {
777+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 3}); // subsequence_begins
778+
779+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
780+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
781+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
782+
}
770783
// if (params.inputs.size() >= 4)
771784
// kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 3}); // mask
772785
// if (params.inputs.size() >= 5)

0 commit comments

Comments
 (0)