9
9
10
10
namespace kernel_selector {
11
11
12
// Number of source-sequence elements each partition covers when the sequence is
// split: the partition count is CeilDiv(source_seq_len, seq_len_partition_size),
// and the value is also exported to the kernels as the SEQ_LEN_PARTITION_SIZE
// and SLM_SIZE jit constants.
constexpr size_t seq_len_partition_size = 32 ;

// Subgroup width the SDPA kernels are compiled for; exported as the
// SUBGROUP_SIZE jit constant and used as the local work-group size of the
// finalization stage.
constexpr size_t subgroup_size = 16 ;
15
15
ParamsKey SDPAKernelOpt::GetSupportedKey () const {
@@ -48,11 +48,12 @@ bool SDPAKernelOpt::Validate(const Params& p) const {
48
48
return true ;
49
49
}
50
50
51
- JitConstants SDPAKernelOpt::GetJitConstants (const sdpa_params& params) const {
52
- auto jit = Parent::GetJitConstants (params);
51
+ JitConstants SDPAKernelOpt::GetJitConstants (const sdpa_params& params, size_t kernel_idx ) const {
52
+ auto jit = MakeBaseParamsJitConstants (params);
53
53
54
- const auto softmax_acc_dt = params.inputs [0 ].GetDType ();
55
- jit.Merge (MakeTypeJitConstants (softmax_acc_dt, " ACCUMULATOR" ));
54
+ const auto softmax_acc_dt = Datatype::F32;
55
+ // const auto softmax_acc_dt = params.inputs[0].GetDType();
56
+ jit.Merge (MakeTypeJitConstants (softmax_acc_dt, " SOFTMAX_ACCUMULATOR" ));
56
57
57
58
const auto & config = params.conf ;
58
59
jit.AddConstant (MakeJitConstant (" SUBGROUP_SIZE" , subgroup_size));
@@ -63,6 +64,7 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params) const {
63
64
jit.AddConstant (MakeJitConstant (" USE_SEQ_LEN_SPLIT" , 1 ));
64
65
jit.AddConstant (MakeJitConstant (" SEQ_LEN_PARTITION_SIZE" , seq_len_partition_size));
65
66
jit.AddConstant (MakeJitConstant (" SLM_SIZE" , seq_len_partition_size));
67
+ jit.AddConstant (MakeJitConstant (" SDPA_STAGE_" + std::to_string (kernel_idx), 1 ));
66
68
67
69
return jit;
68
70
}
@@ -74,18 +76,19 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
74
76
const auto & key_input = params.inputs [1 ];
75
77
const auto & output = params.outputs [0 ];
76
78
if (!query_input.is_dynamic ()) {
77
- const size_t seq_len = key_input.Y ().v ;
78
- const size_t num_of_partitions = CeilDiv (seq_len, seq_len_partition_size);
79
+ const size_t source_seq_len = key_input.Y ().v ;
80
+ const size_t target_seq_len = output.Y ().v ;
81
+ const size_t num_of_partitions = CeilDiv (source_seq_len, seq_len_partition_size);
79
82
const size_t head_size = static_cast <size_t >(params.conf .head_size );
80
83
81
84
if (kernel_idx == 0 ) {
82
85
dispatch_data.gws = { output.Batch ().v * output.Feature ().v ,
83
- output. Y (). v ,
86
+ target_seq_len ,
84
87
head_size * num_of_partitions };
85
88
dispatch_data.lws = { 1 , 1 , head_size };
86
89
} else {
87
90
dispatch_data.gws = { output.Batch ().v * output.Feature ().v ,
88
- output. Y (). v ,
91
+ target_seq_len ,
89
92
head_size };
90
93
dispatch_data.lws = { 1 , 1 , subgroup_size };
91
94
}
@@ -95,89 +98,122 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
95
98
}
96
99
97
100
KernelsData SDPAKernelOpt::GetKernelsData (const Params& params) const {
    if (!Validate (params)) {
        return {};
    }

    // Two-stage execution: kernel 0 computes per-partition SDPA results,
    // kernel 1 ("sdpa_opt_finalization") reduces the partial results into the output.
    const size_t kernels_num = 2;
    KernelData kd = KernelData::Default<sdpa_params>(params, kernels_num);
    kd.needs_sub_kernels_sync = true;

    GetUpdateDispatchDataFunc (kd);

    const auto& prim_params = dynamic_cast<const sdpa_params&>(params);

    // Internal-buffer configuration is identical for both kernels, so it is set up
    // once instead of being cleared and re-pushed on every loop iteration.
    // At build time a single partition is assumed (sizes collapse to 1 byte each);
    // the real, shape-dependent sizes are recomputed in the update-dispatch-data
    // callback installed above.
    const auto num_of_partitions = 1;
    auto& output = prim_params.outputs[0];
    auto head_size = output.X().v;

    // 4 bytes per element — matches the FP32 softmax accumulator the kernels use.
    const size_t buf_dt_size = 4;
    const size_t buf_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() / head_size * num_of_partitions;
    const size_t buf_size = buf_elements_count * buf_dt_size;

    const size_t tmp_out_dt_size = 4;
    const size_t tmp_out_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() * num_of_partitions;
    const size_t tmp_out_size = tmp_out_elements_count * tmp_out_dt_size;

    kd.internalBufferSizes.clear();
    kd.internalBufferSizes.push_back(buf_size);
    kd.internalBufferSizes.push_back(buf_size);
    kd.internalBufferSizes.push_back(tmp_out_size);
    kd.internalBufferDataType = prim_params.inputs[0].GetDType();

    for (size_t kernel_num = 0; kernel_num < kernels_num; kernel_num++) {
        auto dispatch_data = SetDefault(prim_params, kernel_num);
        auto kernel_name = kernel_num == 0 ? kernelName : "sdpa_opt_finalization";
        auto entry_point = GetEntryPoint(kernel_name, prim_params.layerID, params);
        auto jit_constants = GetJitConstants(prim_params, kernel_num);
        auto jit = CreateJit(kernel_name, jit_constants, entry_point);

        auto& kernel = kd.kernels[kernel_num];

        // The finalization stage reads only the internal buffers, not the primitive inputs.
        auto inputs_num = kernel_num == 1 ? 0 : static_cast<int>(prim_params.inputs.size());

        // NOTE(review): the member `kernelName` (not the per-stage `kernel_name`) is
        // passed here — presumably both stages share one OpenCL template selected via
        // the SDPA_STAGE_* jit constant; confirm, otherwise this should be `kernel_name`.
        FillCLKernelData(kernel,
                         dispatch_data,
                         params.engineInfo,
                         kernelName,
                         jit,
                         entry_point,
                         {},
                         false,
                         false,
                         inputs_num,
                         GetFusedPrimitiveInputsCount(params),
                         static_cast<int>(prim_params.outputs.size()),
                         prim_params.is_shape_agnostic);

        // Both stages bind the three internal buffers (two partial-softmax buffers
        // plus the temporary output) after the regular input/output arguments.
        kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
        kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
        kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});

        GPU_DEBUG_TRACE_DETAIL << "configure SDPA " << kernel_num << "th kernel: inputs_num=" << inputs_num << " arguments_num=" << kernel.params.arguments.size() << "\n";

        if (kernel_num == 1) {
            // The finalization kernel additionally receives the partitions count as a
            // runtime scalar (updated later by the dispatch-data callback).
            kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});

            ScalarDescriptor num_of_partitions_scalar;
            num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
            num_of_partitions_scalar.v.u32 = num_of_partitions;

            kernel.params.scalars.clear();
            kernel.params.scalars.push_back(num_of_partitions_scalar);
        }
    }

    return { kd };
}
149
174
150
175
void SDPAKernelOpt::GetUpdateDispatchDataFunc (KernelData& kd) const {
151
176
kd.update_dispatch_data_func = [this ](const Params& params, KernelData& kernel_data) {
152
177
const auto & prim_params = static_cast <const sdpa_params&>(params);
153
- auto dispatchData = SetDefault (prim_params, 0 );
154
- OPENVINO_ASSERT (kernel_data.kernels .size () == 1 , " [GPU] Invalid kernels size for update dispatch data func" );
155
- kernel_data.kernels [0 ].params .workGroups .global = dispatchData.gws ;
156
- kernel_data.kernels [0 ].params .workGroups .local = dispatchData.lws ;
157
- kernel_data.kernels [0 ].skip_execution = KernelData::SkipKernelExecution (prim_params);
158
-
159
- // auto& in_q = prim_params.inputs[0];
160
- // auto& in_k = prim_params.inputs[1];
161
178
162
- // ScalarDescriptor num_of_partitions_scalar ;
163
- // num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
164
- // num_of_partitions_scalar.v.u32 = 1 ;
179
+ const size_t expected_kernels_num = 2 ;
180
+ OPENVINO_ASSERT (kernel_data. kernels . size () == expected_kernels_num,
181
+ " [GPU] Invalid kernels size for update dispatch data func of SDPA kernel " ) ;
165
182
166
- // kernel_data.kernels [0].params.scalars.resize(1) ;
167
- // kernel_data.kernels[0].params.scalars[0] = num_of_partitions_scalar ;
183
+ auto & output = prim_params. outputs [0 ];
184
+ auto & key_input = prim_params. inputs [ 1 ] ;
168
185
169
- auto num_of_partitions = 1 ;
186
+ auto head_size = output.X ().v ;
187
+ auto source_seq_len = key_input.Y ().v ;
188
+ auto num_of_partitions = CeilDiv (source_seq_len, seq_len_partition_size);
170
189
171
- auto & output = prim_params.outputs [0 ];
172
190
auto buf_dt_size = 4 ;
173
- // auto buf_elements_count = tokens_num * prim_params.configuration.heads_num * num_of_portions;
174
- auto buf_elements_count = output.LogicalSize () / output.X ().v * num_of_partitions;
191
+ auto buf_elements_count = (num_of_partitions == 1 ) ? 1 : output.LogicalSize () / head_size * num_of_partitions;
175
192
auto buf_size = buf_elements_count * buf_dt_size;
176
193
177
194
auto tmp_out_dt_size = 4 ;
178
- auto tmp_out_elements_count = output. LogicalSize () / output.X (). v * num_of_partitions * prim_params. conf . head_size ;
195
+ auto tmp_out_elements_count = (num_of_partitions == 1 ) ? 1 : output.LogicalSize () * num_of_partitions;
179
196
auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size;
180
197
198
+ auto dispatch_data1 = SetDefault (prim_params, 0 );
199
+ kernel_data.kernels [0 ].params .workGroups .global = dispatch_data1.gws ;
200
+ kernel_data.kernels [0 ].params .workGroups .local = dispatch_data1.lws ;
201
+ kernel_data.kernels [0 ].skip_execution = KernelData::SkipKernelExecution (prim_params);
202
+
203
+ ScalarDescriptor num_of_partitions_scalar;
204
+ num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
205
+ num_of_partitions_scalar.v .u32 = num_of_partitions;
206
+
207
+ auto dispatch_data2 = SetDefault (prim_params, 1 );
208
+ kernel_data.kernels [1 ].params .workGroups .global = dispatch_data2.gws ;
209
+ kernel_data.kernels [1 ].params .workGroups .local = dispatch_data2.lws ;
210
+ kernel_data.kernels [1 ].skip_execution = num_of_partitions == 1 ;
211
+
212
+ kernel_data.kernels [1 ].params .scalars .clear ();
213
+ kernel_data.kernels [1 ].params .scalars .push_back (num_of_partitions_scalar);
214
+ GPU_DEBUG_TRACE_DETAIL << " update_dispatch_data_func SDPA 0th kernel: arguments_num=" << kernel_data.kernels [0 ].params .arguments .size () << " \n " ;
215
+ GPU_DEBUG_TRACE_DETAIL << " update_dispatch_data_func SDPA 1th kernel: arguments_num=" << kernel_data.kernels [1 ].params .arguments .size () << " \n " ;
216
+
181
217
kernel_data.internalBufferSizes .clear ();
182
218
kernel_data.internalBufferSizes .push_back (buf_size);
183
219
kernel_data.internalBufferSizes .push_back (buf_size);
0 commit comments