@@ -158,6 +158,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
158
158
kernel_offset += _kernels_data[s].kernels .size ();
159
159
}
160
160
for (size_t kd_idx = 0 ; kd_idx < _kernels_data[stage].kernels .size (); ++kd_idx) {
161
+ auto time0 = std::chrono::high_resolution_clock::now ();
161
162
if (_kernels_data[stage].kernels [kd_idx].skip_execution )
162
163
continue ;
163
164
@@ -166,14 +167,23 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
166
167
bool needs_completion_event = instance.needs_completion_event ();
167
168
168
169
auto & params = _kernels_data[stage].kernels [kd_idx].params ;
170
+
171
+
169
172
auto args = get_arguments (instance, stage);
170
173
args.scalars = ¶ms.scalars ;
171
174
172
175
for (const auto & m : instance.get_intermediates_memories ()) {
173
176
args.intermediates .push_back (m);
174
177
}
175
178
179
+ // if (stage == Stage::SDPA && kd_idx != 0) {
180
+ // auto& inputs = args.inputs;
181
+ // inputs.erase(inputs.begin(), inputs.begin() + 7);
182
+ // }
183
+
184
+ auto time1 = std::chrono::high_resolution_clock::now ();
176
185
stream.set_arguments (*_kernels[idx_final], _kernels_data[stage].kernels [kd_idx].params , args);
186
+ auto time2 = std::chrono::high_resolution_clock::now ();
177
187
178
188
const auto & gws = params.workGroups .global ;
179
189
const auto & lws = params.workGroups .local ;
@@ -183,30 +193,38 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
183
193
<< (needs_completion_event ? " has_completion_event=true" : " " ) << std::endl;
184
194
185
195
auto ev = stream.enqueue_kernel (*_kernels[idx_final], params, args, tmp_events, needs_completion_event);
196
+ auto time3 = std::chrono::high_resolution_clock::now ();
186
197
if (_kernels_data[stage].needs_sub_kernels_sync ) {
187
198
tmp_events = {ev};
188
199
}
200
+
201
+ auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(time1 - time0).count ();
202
+ auto time_res1 = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1).count ();
203
+ auto time_res2 = std::chrono::duration_cast<std::chrono::microseconds>(time3 - time2).count ();
204
+ GPU_DEBUG_TRACE_DETAIL << " Time execute_stage inside = " << time_res0 << " " << time_res1 << " " << time_res2 << " \n " ;
205
+
189
206
all_events.push_back (ev);
190
207
}
191
208
192
209
193
- if (instance.get_network ().get_config ().get_property (ov::enable_profiling)) {
194
- auto final_event = stream.group_events (all_events);
195
- if (final_event != nullptr ) {
196
- stream.wait_for_events ({final_event});
197
- auto profiling_info = final_event->get_profiling_info ();
198
- for (const auto &interval : profiling_info) {
199
- if (interval.stage == cldnn::instrumentation::profiling_stage::executing) {
200
- auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(interval.value ->value ()).count ();
201
- GPU_DEBUG_INFO << " PagedAttention " << stage << " stage time: " << time_res0 << " mcs\n " ;
202
- }
203
- }
204
- }
205
- }
210
+ // if (instance.get_network().get_config().get_property(ov::enable_profiling)) {
211
+ // auto final_event = stream.group_events(all_events);
212
+ // if (final_event != nullptr) {
213
+ // stream.wait_for_events({final_event});
214
+ // auto profiling_info = final_event->get_profiling_info();
215
+ // for (const auto &interval : profiling_info) {
216
+ // if (interval.stage == cldnn::instrumentation::profiling_stage::executing) {
217
+ // auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(interval.value->value()).count();
218
+ // GPU_DEBUG_INFO << "PagedAttention " << stage << " stage time: " << time_res0 << " mcs\n";
219
+ // }
220
+ // }
221
+ // }
222
+ // }
206
223
}
207
224
208
225
event::ptr execute_impl (const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
209
226
auto & stream = instance.get_network ().get_stream ();
227
+ auto time0 = std::chrono::high_resolution_clock::now ();
210
228
// auto& service_stream = instance.get_network().get_engine().get_service_stream();
211
229
std::vector<event::ptr> res_events;
212
230
@@ -217,6 +235,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
217
235
// GPU_DEBUG_TRACE_DETAIL << instance.id() << " stage is " << (is_prefill_stage ? "prefill" : "tokens generating") << "\n";
218
236
219
237
execute_stage (events, instance, res_events, Stage::KV_CACHE_UPDATE);
238
+ auto time1 = std::chrono::high_resolution_clock::now ();
220
239
221
240
if (false ) {
222
241
// auto sliding_window_memory = instance.input_memory_ptr(12);
@@ -291,12 +310,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
291
310
for (auto & ev : res_events)
292
311
all_events.push_back (ev);
293
312
294
- auto impl_param = *instance.get_impl_params ();
295
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, impl_param .is_dynamic ());
296
- (_kernels_data[Stage::SDPA].update_dispatch_data_func )(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
313
+ // const auto impl_params = *instance.get_impl_params();
314
+ // auto sdpa_kernel_params = get_sdpa_kernel_params(impl_params, impl_params .is_dynamic());
315
+ // (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
297
316
298
317
execute_stage (all_events, instance, res_events, Stage::SDPA);
299
318
319
+ auto time2 = std::chrono::high_resolution_clock::now ();
320
+ auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(time1 - time0).count ();
321
+ auto time_res1 = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1).count ();
322
+ GPU_DEBUG_TRACE_DETAIL << " Time PA = " << time_res0 << " " << time_res1 << " \n " ;
323
+
300
324
return aggregate_events (res_events, stream, res_events.size () > 1 );
301
325
}
302
326
}
@@ -331,6 +355,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
331
355
config.kv_heads_num = kv_heads_num;
332
356
config.block_size = block_size;
333
357
config.x_size = x_size;
358
+ config.max_context_len = 1 ;
334
359
}
335
360
336
361
return config;
@@ -397,6 +422,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
397
422
params.inputs [6 ] = convert_data_tensor (scale_layout);
398
423
399
424
params.configuration = get_sdpa_configuration (impl_param);
425
+ GPU_DEBUG_TRACE_DETAIL << " Number of constant_mem " << impl_param.memory_deps .size () << " , dynamic=" << is_dynamic << " \n " ;
426
+ if (!is_dynamic) {
427
+ auto & constant_mem = impl_param.memory_deps ;
428
+
429
+
430
+ const auto max_context_len_mem = constant_mem.at (7 );
431
+ mem_lock<int32_t , mem_lock_type::read > max_context_len_mem_lock (max_context_len_mem, impl_param.get_stream ());
432
+ GPU_DEBUG_TRACE_DETAIL << " max_context_len_mem_lock=" << max_context_len_mem_lock[0 ] << " \n " ;
433
+
434
+ const auto is_prompt_stage_mem = constant_mem.at (5 );
435
+ mem_lock<uint8_t , mem_lock_type::read > is_prompt_stage_mem_lock (is_prompt_stage_mem, impl_param.get_stream ());
436
+ bool is_prompt_stage = is_prompt_stage_mem_lock[0 ];
437
+
438
+ if (is_prompt_stage) {
439
+ // Use number of slots for KV cache as a maximum context length for the first iteration
440
+ auto slot_mapping = impl_param.get_input_layout (6 );
441
+ params.configuration .max_context_len = slot_mapping.get_shape ()[1 ];
442
+ } else {
443
+ const auto max_context_len_mem = constant_mem.at (7 );
444
+ mem_lock<int32_t , mem_lock_type::read > max_context_len_mem_lock (max_context_len_mem, impl_param.get_stream ());
445
+ params.configuration .max_context_len = max_context_len_mem_lock[0 ];
446
+ }
447
+ }
400
448
401
449
const auto & in_offsets_map = impl_param.in_port_to_shape_info_offset ;
402
450
const auto & out_offsets_map = impl_param.out_port_to_shape_info_offset ;
@@ -434,6 +482,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
434
482
void update_dispatch_data (const kernel_impl_params& impl_param) override {
435
483
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params (impl_param, impl_param.is_dynamic ());
436
484
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func )(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
485
+
486
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, impl_param.is_dynamic ());
487
+ (_kernels_data[Stage::SDPA].update_dispatch_data_func )(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
437
488
}
438
489
};
439
490
0 commit comments