Skip to content

Commit 285b268

Browse files
committed
SDPA opt tests
1 parent 79cc1a4 commit 285b268

File tree

4 files changed

+170
-113
lines changed

4 files changed

+170
-113
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl

+17-17
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,11 @@ KERNEL(sdpa_opt)(
412412
}
413413
#endif
414414

415-
ulong timer_start2 = intel_get_cycle_counter();
416-
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
417-
ulong diff = timer_start2 - timer_start1;
418-
printf("Gemm1 time: %lu\n", diff);
419-
}
415+
// ulong timer_start2 = intel_get_cycle_counter();
416+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
417+
// ulong diff = timer_start2 - timer_start1;
418+
// printf("Gemm1 time: %lu\n", diff);
419+
// }
420420

421421
} // finish Gemm1
422422

@@ -519,11 +519,11 @@ KERNEL(sdpa_opt)(
519519
}
520520

521521

522-
ulong timer_start2 = intel_get_cycle_counter();
523-
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
524-
ulong diff = timer_start2 - timer_start1;
525-
printf("%d. Softmax time: %lu\n", 0, diff);
526-
}
522+
// ulong timer_start2 = intel_get_cycle_counter();
523+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
524+
// ulong diff = timer_start2 - timer_start1;
525+
// printf("%d. Softmax time: %lu\n", 0, diff);
526+
// }
527527

528528
}
529529

@@ -533,7 +533,7 @@ KERNEL(sdpa_opt)(
533533

534534
OUTPUT_TYPE acc[SEQ_ID_BLOCK_SIZE] = {OUTPUT_VAL_ZERO};
535535
for (uint seq_len = 0; seq_len < partition_seq_len / SUBGROUP_SIZE; seq_len++) {
536-
uint value_offset = INPUT1_GET_INDEX(batch_idx, head_num_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx);
536+
uint value_offset = INPUT2_GET_INDEX(batch_idx, head_num_idx, start_partition_idx + (seq_len * SUBGROUP_SIZE), head_size_idx);
537537

538538
OUTPUT_TYPE qk_val[SEQ_ID_BLOCK_SIZE];
539539
unroll_for (uint seq_idx_index = 0; seq_idx_index < SEQ_ID_BLOCK_SIZE; seq_idx_index++) {
@@ -555,7 +555,7 @@ KERNEL(sdpa_opt)(
555555
/* TODO: Remove if */
556556
if (seq_len_leftover_start != partition_seq_len) {
557557
for (uint seq_len = seq_len_leftover_start; seq_len < partition_seq_len; seq_len++) {
558-
const uint value_offset = INPUT1_GET_INDEX(batch_idx, head_num_idx, start_partition_idx + seq_len, head_size_idx);
558+
const uint value_offset = INPUT2_GET_INDEX(batch_idx, head_num_idx, start_partition_idx + seq_len, head_size_idx);
559559

560560
OUTPUT_TYPE qk_val[SEQ_ID_BLOCK_SIZE];
561561
unroll_for (uint seq_idx_index = 0; seq_idx_index < SEQ_ID_BLOCK_SIZE; seq_idx_index++) {
@@ -599,11 +599,11 @@ KERNEL(sdpa_opt)(
599599
}
600600

601601

602-
ulong timer_start2 = intel_get_cycle_counter();
603-
if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
604-
ulong diff = timer_start2 - timer_start1;
605-
printf("Gemm2 time: %lu\n", diff);
606-
}
602+
// ulong timer_start2 = intel_get_cycle_counter();
603+
// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
604+
// ulong diff = timer_start2 - timer_start1;
605+
// printf("Gemm2 time: %lu\n", diff);
606+
// }
607607
}
608608
}
609609

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp

+51-47
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ std::string convert_to(const std::string &str) {
2626

2727
static size_t get_seq_id_block_size() {
2828
static bool called = false;
29-
size_t block_size = 1;
29+
size_t block_size = 8;
3030
if (const auto env_var = std::getenv("BLOCK_SIZE")) {
3131
block_size = convert_to<size_t>(env_var);
3232
}
@@ -53,6 +53,21 @@ static size_t get_seq_len_partition_size() {
5353
return seq_len;
5454
}
5555

56+
static size_t get_mul_num() {
57+
static bool called = false;
58+
size_t muls_num = 8;
59+
if (const auto env_var = std::getenv("MULS_NUM")) {
60+
muls_num = convert_to<size_t>(env_var);
61+
}
62+
63+
if (!called) {
64+
std::cout << "Set muls_num = " << muls_num << "\n";
65+
called = true;
66+
}
67+
return muls_num;
68+
}
69+
70+
5671
ParamsKey SDPAKernelOpt::GetSupportedKey() const {
5772
ParamsKey k;
5873
k.EnableInputDataType(Datatype::F16);
@@ -96,35 +111,7 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
96111
const auto softmax_acc_dt = params.inputs[0].GetDType();
97112
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
98113

99-
if (const auto env_var = std::getenv("MULS_NUM")) {
100-
auto muls_num = convert_to<size_t>(env_var);
101-
jit.AddConstant(MakeJitConstant("MULS_NUM", muls_num));
102-
}
103-
104-
if (const auto env_var = std::getenv("FIRST_APPROACH_OPTION2")) {
105-
jit.AddConstant(MakeJitConstant("FIRST_APPROACH_OPTION2", 1));
106-
}
107-
108-
if (const auto env_var = std::getenv("FIRST_APPROACH")) {
109-
jit.AddConstant(MakeJitConstant("FIRST_APPROACH", 1));
110-
}
111-
112-
static bool printed = false;
113-
if (!printed) {
114-
printed = true;
115-
std::cout << "Input " << static_cast<int>(params.inputs[0].GetDType()) << " " << static_cast<int>(params.outputs[0].GetDType()) << " " << static_cast<int>(softmax_acc_dt) << "\n";
116-
117-
if (const auto env_var = std::getenv("MULS_NUM")) {
118-
auto muls_num = convert_to<size_t>(env_var);
119-
std::cout << "Force MULS_NUM to " << muls_num << "\n";
120-
}
121-
if (const auto env_var = std::getenv("FIRST_APPROACH_OPTION2")) {
122-
std::cout << "Force FIRST_APPROACH_OPTION2\n";
123-
}
124-
if (const auto env_var = std::getenv("FIRST_APPROACH")) {
125-
std::cout << "Force FIRST_APPROACH\n";
126-
}
127-
}
114+
jit.AddConstant(MakeJitConstant("MULS_NUM", get_mul_num()));
128115

129116
const auto& config = params.conf;
130117
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
@@ -135,9 +122,17 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
135122
jit.AddConstant(MakeJitConstant("USE_SEQ_LEN_SPLIT", 1));
136123
jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size()));
137124
jit.AddConstant(MakeJitConstant("SLM_SIZE", get_seq_len_partition_size()));
138-
jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(kernel_idx), 1));
139125

140-
jit.AddConstant(MakeJitConstant("SEQ_ID_BLOCK_SIZE", get_seq_id_block_size()));
126+
// kernel_idx == 0 - single token opt
127+
// kernel_idx == 1 - multi token opt
128+
// kernel_idx == 2 - finalization
129+
if (kernel_idx == 0)
130+
jit.AddConstant(MakeJitConstant("SEQ_ID_BLOCK_SIZE", 1));
131+
else
132+
jit.AddConstant(MakeJitConstant("SEQ_ID_BLOCK_SIZE", get_seq_id_block_size()));
133+
134+
auto sdpa_stage = kernel_idx == 2 ? 1 : 0;
135+
jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1));
141136

142137
return jit;
143138
}
@@ -153,10 +148,11 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
153148
const size_t target_seq_len = output.Y().v;
154149
const size_t num_of_partitions = CeilDiv(source_seq_len, get_seq_len_partition_size());
155150
const size_t head_size = static_cast<size_t>(params.conf.head_size);
151+
const size_t block_size = kernel_idx == 1 ? get_seq_id_block_size() : 1;
156152

157-
if (kernel_idx == 0) {
153+
if (kernel_idx == 0 || kernel_idx == 1) {
158154
dispatch_data.gws = { output.Batch().v * output.Feature().v,
159-
CeilDiv(target_seq_len, get_seq_id_block_size()),
155+
CeilDiv(target_seq_len, block_size),
160156
head_size * num_of_partitions };
161157
dispatch_data.lws = { 1, 1, head_size };
162158
} else {
@@ -175,7 +171,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
175171
return {};
176172
}
177173

178-
const size_t kernels_num = 2;
174+
const size_t kernels_num = 3;
179175
KernelData kd = KernelData::Default<sdpa_params>(params, kernels_num);
180176
kd.needs_sub_kernels_sync = true;
181177

@@ -184,14 +180,15 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
184180
const auto& prim_params = dynamic_cast<const sdpa_params&>(params);
185181
for (size_t kernel_num = 0; kernel_num < kernels_num; kernel_num++) {
186182
auto dispatch_data = SetDefault(prim_params, kernel_num);
187-
auto kernel_name = kernel_num == 0 ? kernelName : "sdpa_opt_finalization";
183+
auto kernel_name = kernel_num == 0 ? kernelName + "_single_token" :
184+
kernel_num == 1 ? kernelName + "_multi_tokens" : "sdpa_opt_finalization";
188185
auto entry_point = GetEntryPoint(kernel_name, prim_params.layerID, params);
189186
auto jit_constants = GetJitConstants(prim_params, kernel_num);
190187
auto jit = CreateJit(kernel_name, jit_constants, entry_point);
191188

192189
auto& kernel = kd.kernels[kernel_num];
193190

194-
auto inputs_num = kernel_num == 1 ? 0 : static_cast<int>(prim_params.inputs.size());
191+
auto inputs_num = kernel_num == 2 ? 0 : static_cast<int>(prim_params.inputs.size());
195192
FillCLKernelData(kernel,
196193
dispatch_data,
197194
params.engineInfo,
@@ -230,7 +227,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
230227

231228
GPU_DEBUG_TRACE_DETAIL << "configure SDPA " << kernel_num << "th kernel: inputs_num=" << inputs_num << " arguments_num=" << kernel.params.arguments.size() << "\n";
232229

233-
if (kernel_num == 1) {
230+
if (kernel_num == 2) {
234231
kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
235232

236233
ScalarDescriptor num_of_partitions_scalar;
@@ -249,13 +246,14 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
249246
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kernel_data) {
250247
const auto& prim_params = static_cast<const sdpa_params&>(params);
251248

252-
const size_t expected_kernels_num = 2;
249+
const size_t expected_kernels_num = 3;
253250
OPENVINO_ASSERT(kernel_data.kernels.size() == expected_kernels_num,
254251
"[GPU] Invalid kernels size for update dispatch data func of SDPA kernel");
255252

256253
auto& output = prim_params.outputs[0];
257254
auto& key_input = prim_params.inputs[1];
258255

256+
auto seq_num = output.Y().v;
259257
auto head_size = output.X().v;
260258
auto source_seq_len = key_input.Y().v;
261259
auto num_of_partitions = CeilDiv(source_seq_len, get_seq_len_partition_size());
@@ -271,21 +269,27 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
271269
auto dispatch_data1 = SetDefault(prim_params, 0);
272270
kernel_data.kernels[0].params.workGroups.global = dispatch_data1.gws;
273271
kernel_data.kernels[0].params.workGroups.local = dispatch_data1.lws;
274-
kernel_data.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params);
272+
kernel_data.kernels[0].skip_execution = seq_num > 1;
273+
274+
auto dispatch_data2 = SetDefault(prim_params, 1);
275+
kernel_data.kernels[1].params.workGroups.global = dispatch_data2.gws;
276+
kernel_data.kernels[1].params.workGroups.local = dispatch_data2.lws;
277+
kernel_data.kernels[1].skip_execution = seq_num == 1;
275278

276279
ScalarDescriptor num_of_partitions_scalar;
277280
num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
278281
num_of_partitions_scalar.v.u32 = num_of_partitions;
279282

280-
auto dispatch_data2 = SetDefault(prim_params, 1);
281-
kernel_data.kernels[1].params.workGroups.global = dispatch_data2.gws;
282-
kernel_data.kernels[1].params.workGroups.local = dispatch_data2.lws;
283-
kernel_data.kernels[1].skip_execution = num_of_partitions == 1;
283+
auto dispatch_data3 = SetDefault(prim_params, 2);
284+
kernel_data.kernels[2].params.workGroups.global = dispatch_data3.gws;
285+
kernel_data.kernels[2].params.workGroups.local = dispatch_data3.lws;
286+
kernel_data.kernels[2].skip_execution = num_of_partitions == 1;
284287

285-
kernel_data.kernels[1].params.scalars.clear();
286-
kernel_data.kernels[1].params.scalars.push_back(num_of_partitions_scalar);
288+
kernel_data.kernels[2].params.scalars.clear();
289+
kernel_data.kernels[2].params.scalars.push_back(num_of_partitions_scalar);
287290
GPU_DEBUG_TRACE_DETAIL << "update_dispatch_data_func SDPA 0th kernel: arguments_num=" << kernel_data.kernels[0].params.arguments.size() << "\n";
288-
GPU_DEBUG_TRACE_DETAIL << "update_dispatch_data_func SDPA 1th kernel: arguments_num=" << kernel_data.kernels[1].params.arguments.size() << "\n";
291+
GPU_DEBUG_TRACE_DETAIL << "update_dispatch_data_func SDPA 1th kernel: arguments_num=" << kernel_data.kernels[2].params.arguments.size() << "\n";
292+
GPU_DEBUG_TRACE_DETAIL << "update_dispatch_data_func SDPA 3th kernel: arguments_num=" << kernel_data.kernels[3].params.arguments.size() << "\n";
289293

290294
kernel_data.internalBufferSizes.clear();
291295
kernel_data.internalBufferSizes.push_back(buf_size);

0 commit comments

Comments
 (0)