@@ -291,9 +291,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
291
291
for (auto & ev : res_events)
292
292
all_events.push_back (ev);
293
293
294
- auto impl_param = *instance.get_impl_params ();
295
- auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, impl_param .is_dynamic ());
296
- (_kernels_data[Stage::SDPA].update_dispatch_data_func )(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
294
+ // const auto impl_params = *instance.get_impl_params();
295
+ // auto sdpa_kernel_params = get_sdpa_kernel_params(impl_params, impl_params .is_dynamic());
296
+ // (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
297
297
298
298
execute_stage (all_events, instance, res_events, Stage::SDPA);
299
299
@@ -331,6 +331,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
331
331
config.kv_heads_num = kv_heads_num;
332
332
config.block_size = block_size;
333
333
config.x_size = x_size;
334
+ config.max_context_len = 1 ;
334
335
}
335
336
336
337
return config;
@@ -397,6 +398,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
397
398
params.inputs [6 ] = convert_data_tensor (scale_layout);
398
399
399
400
params.configuration = get_sdpa_configuration (impl_param);
401
+ GPU_DEBUG_TRACE_DETAIL << " Number of constant_mem " << impl_param.memory_deps .size () << " , dynamic=" << is_dynamic << " \n " ;
402
+ if (!is_dynamic) {
403
+ auto & constant_mem = impl_param.memory_deps ;
404
+
405
+
406
+ const auto max_context_len_mem = constant_mem.at (7 );
407
+ mem_lock<int32_t , mem_lock_type::read > max_context_len_mem_lock (max_context_len_mem, impl_param.get_stream ());
408
+ GPU_DEBUG_TRACE_DETAIL << " max_context_len_mem_lock=" << max_context_len_mem_lock[0 ] << " \n " ;
409
+
410
+ const auto is_prompt_stage_mem = constant_mem.at (5 );
411
+ mem_lock<uint8_t , mem_lock_type::read > is_prompt_stage_mem_lock (is_prompt_stage_mem, impl_param.get_stream ());
412
+ bool is_prompt_stage = is_prompt_stage_mem_lock[0 ];
413
+
414
+ if (is_prompt_stage) {
415
+ // Use number of slots for KV cache as a maximum context length for the first iteration
416
+ auto slot_mapping = impl_param.get_input_layout (6 );
417
+ params.configuration .max_context_len = slot_mapping.get_shape ()[1 ];
418
+ } else {
419
+ const auto max_context_len_mem = constant_mem.at (7 );
420
+ mem_lock<int32_t , mem_lock_type::read > max_context_len_mem_lock (max_context_len_mem, impl_param.get_stream ());
421
+ params.configuration .max_context_len = max_context_len_mem_lock[0 ];
422
+ }
423
+ }
400
424
401
425
const auto & in_offsets_map = impl_param.in_port_to_shape_info_offset ;
402
426
const auto & out_offsets_map = impl_param.out_port_to_shape_info_offset ;
@@ -434,6 +458,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
434
458
void update_dispatch_data (const kernel_impl_params& impl_param) override {
435
459
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params (impl_param, impl_param.is_dynamic ());
436
460
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func )(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
461
+
462
+ auto sdpa_kernel_params = get_sdpa_kernel_params (impl_param, impl_param.is_dynamic ());
463
+ (_kernels_data[Stage::SDPA].update_dispatch_data_func )(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
437
464
}
438
465
};
439
466
0 commit comments