@@ -26,7 +26,7 @@ std::string convert_to(const std::string &str) {
26
26
27
27
static size_t get_seq_id_block_size () {
28
28
static bool called = false ;
29
- size_t block_size = 1 ;
29
+ size_t block_size = 8 ;
30
30
if (const auto env_var = std::getenv (" BLOCK_SIZE" )) {
31
31
block_size = convert_to<size_t >(env_var);
32
32
}
@@ -53,6 +53,21 @@ static size_t get_seq_len_partition_size() {
53
53
return seq_len;
54
54
}
55
55
56
+ static size_t get_mul_num () {  // Returns the value passed as the MULS_NUM JIT constant: 8 unless the MULS_NUM environment variable is set.
57
+ static bool called = false ;  // one-shot flag: the chosen value is printed only on the first call
58
+ size_t muls_num = 8 ;  // default used when the environment variable is absent
59
+ if (const auto env_var = std::getenv (" MULS_NUM" )) {  // the environment is re-read on every call, not cached
60
+ muls_num = convert_to<size_t >(env_var);  // NOTE(review): convert_to is defined outside this view; its behavior on malformed input is not visible here
61
+ }
62
+
63
+ if (!called) {
64
+ std::cout << " Set muls_num = " << muls_num << " \n " ;  // logged to stdout once per process
65
+ called = true ;
66
+ }
67
+ return muls_num;
68
+ }
69
+
70
+
56
71
ParamsKey SDPAKernelOpt::GetSupportedKey () const {
57
72
ParamsKey k;
58
73
k.EnableInputDataType (Datatype::F16);
@@ -96,35 +111,7 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
96
111
const auto softmax_acc_dt = params.inputs [0 ].GetDType ();
97
112
jit.Merge (MakeTypeJitConstants (softmax_acc_dt, " SOFTMAX_ACCUMULATOR" ));
98
113
99
- if (const auto env_var = std::getenv (" MULS_NUM" )) {
100
- auto muls_num = convert_to<size_t >(env_var);
101
- jit.AddConstant (MakeJitConstant (" MULS_NUM" , muls_num));
102
- }
103
-
104
- if (const auto env_var = std::getenv (" FIRST_APPROACH_OPTION2" )) {
105
- jit.AddConstant (MakeJitConstant (" FIRST_APPROACH_OPTION2" , 1 ));
106
- }
107
-
108
- if (const auto env_var = std::getenv (" FIRST_APPROACH" )) {
109
- jit.AddConstant (MakeJitConstant (" FIRST_APPROACH" , 1 ));
110
- }
111
-
112
- static bool printed = false ;
113
- if (!printed) {
114
- printed = true ;
115
- std::cout << " Input " << static_cast <int >(params.inputs [0 ].GetDType ()) << " " << static_cast <int >(params.outputs [0 ].GetDType ()) << " " << static_cast <int >(softmax_acc_dt) << " \n " ;
116
-
117
- if (const auto env_var = std::getenv (" MULS_NUM" )) {
118
- auto muls_num = convert_to<size_t >(env_var);
119
- std::cout << " Force MULS_NUM to " << muls_num << " \n " ;
120
- }
121
- if (const auto env_var = std::getenv (" FIRST_APPROACH_OPTION2" )) {
122
- std::cout << " Force FIRST_APPROACH_OPTION2\n " ;
123
- }
124
- if (const auto env_var = std::getenv (" FIRST_APPROACH" )) {
125
- std::cout << " Force FIRST_APPROACH\n " ;
126
- }
127
- }
114
+ jit.AddConstant (MakeJitConstant (" MULS_NUM" , get_mul_num ()));
128
115
129
116
const auto & config = params.conf ;
130
117
jit.AddConstant (MakeJitConstant (" SUBGROUP_SIZE" , subgroup_size));
@@ -135,9 +122,17 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
135
122
jit.AddConstant (MakeJitConstant (" USE_SEQ_LEN_SPLIT" , 1 ));
136
123
jit.AddConstant (MakeJitConstant (" SEQ_LEN_PARTITION_SIZE" , get_seq_len_partition_size ()));
137
124
jit.AddConstant (MakeJitConstant (" SLM_SIZE" , get_seq_len_partition_size ()));
138
- jit.AddConstant (MakeJitConstant (" SDPA_STAGE_" + std::to_string (kernel_idx), 1 ));
139
125
140
- jit.AddConstant (MakeJitConstant (" SEQ_ID_BLOCK_SIZE" , get_seq_id_block_size ()));
126
+ // kernel_idx == 0 - single token opt
127
+ // kernel_idx == 1 - multi token opt
128
+ // kernel_idx == 2 - finalization
129
+ if (kernel_idx == 0 )
130
+ jit.AddConstant (MakeJitConstant (" SEQ_ID_BLOCK_SIZE" , 1 ));
131
+ else
132
+ jit.AddConstant (MakeJitConstant (" SEQ_ID_BLOCK_SIZE" , get_seq_id_block_size ()));
133
+
134
+ auto sdpa_stage = kernel_idx == 2 ? 1 : 0 ;
135
+ jit.AddConstant (MakeJitConstant (" SDPA_STAGE_" + std::to_string (sdpa_stage), 1 ));
141
136
142
137
return jit;
143
138
}
@@ -153,10 +148,11 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
153
148
const size_t target_seq_len = output.Y ().v ;
154
149
const size_t num_of_partitions = CeilDiv (source_seq_len, get_seq_len_partition_size ());
155
150
const size_t head_size = static_cast <size_t >(params.conf .head_size );
151
+ const size_t block_size = kernel_idx == 1 ? get_seq_id_block_size () : 1 ;
156
152
157
- if (kernel_idx == 0 ) {
153
+ if (kernel_idx == 0 || kernel_idx == 1 ) {
158
154
dispatch_data.gws = { output.Batch ().v * output.Feature ().v ,
159
- CeilDiv (target_seq_len, get_seq_id_block_size () ),
155
+ CeilDiv (target_seq_len, block_size ),
160
156
head_size * num_of_partitions };
161
157
dispatch_data.lws = { 1 , 1 , head_size };
162
158
} else {
@@ -175,7 +171,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
175
171
return {};
176
172
}
177
173
178
- const size_t kernels_num = 2 ;
174
+ const size_t kernels_num = 3 ;
179
175
KernelData kd = KernelData::Default<sdpa_params>(params, kernels_num);
180
176
kd.needs_sub_kernels_sync = true ;
181
177
@@ -184,14 +180,15 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
184
180
const auto & prim_params = dynamic_cast <const sdpa_params&>(params);
185
181
for (size_t kernel_num = 0 ; kernel_num < kernels_num; kernel_num++) {
186
182
auto dispatch_data = SetDefault (prim_params, kernel_num);
187
- auto kernel_name = kernel_num == 0 ? kernelName : " sdpa_opt_finalization" ;
183
+ auto kernel_name = kernel_num == 0 ? kernelName + " _single_token" :
184
+ kernel_num == 1 ? kernelName + " _multi_tokens" : " sdpa_opt_finalization" ;
188
185
auto entry_point = GetEntryPoint (kernel_name, prim_params.layerID , params);
189
186
auto jit_constants = GetJitConstants (prim_params, kernel_num);
190
187
auto jit = CreateJit (kernel_name, jit_constants, entry_point);
191
188
192
189
auto & kernel = kd.kernels [kernel_num];
193
190
194
- auto inputs_num = kernel_num == 1 ? 0 : static_cast <int >(prim_params.inputs .size ());
191
+ auto inputs_num = kernel_num == 2 ? 0 : static_cast <int >(prim_params.inputs .size ());
195
192
FillCLKernelData (kernel,
196
193
dispatch_data,
197
194
params.engineInfo ,
@@ -230,7 +227,7 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
230
227
231
228
GPU_DEBUG_TRACE_DETAIL << " configure SDPA " << kernel_num << " th kernel: inputs_num=" << inputs_num << " arguments_num=" << kernel.params .arguments .size () << " \n " ;
232
229
233
- if (kernel_num == 1 ) {
230
+ if (kernel_num == 2 ) {
234
231
kernel.params .arguments .push_back ({ArgumentDescriptor::Types::SCALAR, 0 });
235
232
236
233
ScalarDescriptor num_of_partitions_scalar;
@@ -249,13 +246,14 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
249
246
kd.update_dispatch_data_func = [this ](const Params& params, KernelData& kernel_data) {
250
247
const auto & prim_params = static_cast <const sdpa_params&>(params);
251
248
252
- const size_t expected_kernels_num = 2 ;
249
+ const size_t expected_kernels_num = 3 ;
253
250
OPENVINO_ASSERT (kernel_data.kernels .size () == expected_kernels_num,
254
251
" [GPU] Invalid kernels size for update dispatch data func of SDPA kernel" );
255
252
256
253
auto & output = prim_params.outputs [0 ];
257
254
auto & key_input = prim_params.inputs [1 ];
258
255
256
+ auto seq_num = output.Y ().v ;
259
257
auto head_size = output.X ().v ;
260
258
auto source_seq_len = key_input.Y ().v ;
261
259
auto num_of_partitions = CeilDiv (source_seq_len, get_seq_len_partition_size ());
@@ -271,21 +269,27 @@ void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
271
269
auto dispatch_data1 = SetDefault (prim_params, 0 );
272
270
kernel_data.kernels [0 ].params .workGroups .global = dispatch_data1.gws ;
273
271
kernel_data.kernels [0 ].params .workGroups .local = dispatch_data1.lws ;
274
- kernel_data.kernels [0 ].skip_execution = KernelData::SkipKernelExecution (prim_params);
272
+ kernel_data.kernels [0 ].skip_execution = seq_num > 1 ;
273
+
274
+ auto dispatch_data2 = SetDefault (prim_params, 1 );
275
+ kernel_data.kernels [1 ].params .workGroups .global = dispatch_data2.gws ;
276
+ kernel_data.kernels [1 ].params .workGroups .local = dispatch_data2.lws ;
277
+ kernel_data.kernels [1 ].skip_execution = seq_num == 1 ;
275
278
276
279
ScalarDescriptor num_of_partitions_scalar;
277
280
num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
278
281
num_of_partitions_scalar.v .u32 = num_of_partitions;
279
282
280
- auto dispatch_data2 = SetDefault (prim_params, 1 );
281
- kernel_data.kernels [1 ].params .workGroups .global = dispatch_data2 .gws ;
282
- kernel_data.kernels [1 ].params .workGroups .local = dispatch_data2 .lws ;
283
- kernel_data.kernels [1 ].skip_execution = num_of_partitions == 1 ;
283
+ auto dispatch_data3 = SetDefault (prim_params, 2 );
284
+ kernel_data.kernels [2 ].params .workGroups .global = dispatch_data3 .gws ;
285
+ kernel_data.kernels [2 ].params .workGroups .local = dispatch_data3 .lws ;
286
+ kernel_data.kernels [2 ].skip_execution = num_of_partitions == 1 ;
284
287
285
- kernel_data.kernels [1 ].params .scalars .clear ();
286
- kernel_data.kernels [1 ].params .scalars .push_back (num_of_partitions_scalar);
288
+ kernel_data.kernels [2 ].params .scalars .clear ();
289
+ kernel_data.kernels [2 ].params .scalars .push_back (num_of_partitions_scalar);
287
290
GPU_DEBUG_TRACE_DETAIL << " update_dispatch_data_func SDPA 0th kernel: arguments_num=" << kernel_data.kernels [0 ].params .arguments .size () << " \n " ;
288
- GPU_DEBUG_TRACE_DETAIL << " update_dispatch_data_func SDPA 1th kernel: arguments_num=" << kernel_data.kernels [1 ].params .arguments .size () << " \n " ;
291
+ GPU_DEBUG_TRACE_DETAIL << " update_dispatch_data_func SDPA 1th kernel: arguments_num=" << kernel_data.kernels [1 ].params .arguments .size () << " \n " ;
292
+ GPU_DEBUG_TRACE_DETAIL << " update_dispatch_data_func SDPA 2th kernel: arguments_num=" << kernel_data.kernels [2 ].params .arguments .size () << " \n " ;  // kernels holds exactly 3 entries (asserted above), so the last valid index is 2
289
293
290
294
kernel_data.internalBufferSizes .clear ();
291
295
kernel_data.internalBufferSizes .push_back (buf_size);
0 commit comments