Commit 2dc55c3

Debug accuracy issue
1 parent 15d02d0 commit 2dc55c3

File tree

6 files changed, +365 −91 lines

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+30 −3
@@ -291,9 +291,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         for (auto& ev : res_events)
             all_events.push_back(ev);
 
-        auto impl_param = *instance.get_impl_params();
-        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
-        (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
+        // const auto impl_params = *instance.get_impl_params();
+        // auto sdpa_kernel_params = get_sdpa_kernel_params(impl_params, impl_params.is_dynamic());
+        // (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
 
         execute_stage(all_events, instance, res_events, Stage::SDPA);

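(In effect, this hunk comments the per-execute refresh of the SDPA kernel params out of the execute path; the same refresh is re-added in update_dispatch_data in the @@ -434 hunk below, so the SDPA stage is reconfigured when dispatch data is updated rather than on every execute call.)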
@@ -331,6 +331,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             config.kv_heads_num = kv_heads_num;
             config.block_size = block_size;
             config.x_size = x_size;
+            config.max_context_len = 1;
         }
 
         return config;
@@ -397,6 +398,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         params.inputs[6] = convert_data_tensor(scale_layout);
 
         params.configuration = get_sdpa_configuration(impl_param);
+        GPU_DEBUG_TRACE_DETAIL << "Number of constant_mem " << impl_param.memory_deps.size() << ", dynamic=" << is_dynamic << "\n";
+        if (!is_dynamic) {
+            auto& constant_mem = impl_param.memory_deps;
+
+
+            const auto max_context_len_mem = constant_mem.at(7);
+            mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
+            GPU_DEBUG_TRACE_DETAIL << "max_context_len_mem_lock=" << max_context_len_mem_lock[0] << "\n";
+
+            const auto is_prompt_stage_mem = constant_mem.at(5);
+            mem_lock<uint8_t, mem_lock_type::read> is_prompt_stage_mem_lock(is_prompt_stage_mem, impl_param.get_stream());
+            bool is_prompt_stage = is_prompt_stage_mem_lock[0];
+
+            if (is_prompt_stage) {
+                // Use the number of KV cache slots as the maximum context length for the first iteration
+                auto slot_mapping = impl_param.get_input_layout(6);
+                params.configuration.max_context_len = slot_mapping.get_shape()[1];
+            } else {
+                const auto max_context_len_mem = constant_mem.at(7);
+                mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
+                params.configuration.max_context_len = max_context_len_mem_lock[0];
+            }
+        }
 
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
         const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
@@ -434,6 +458,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     void update_dispatch_data(const kernel_impl_params& impl_param) override {
         auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
         (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
+
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
+        (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
     }
 };

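For context on the pattern used in the @@ -397 hunk: once the node reports inputs 5 and 7 as shape-infer dependencies (see the header change below), the plugin captures those inputs' buffers in impl_param.memory_deps, and get_sdpa_kernel_params can lock them and read the scalar values on the host. The standalone sketch below models that read pattern only; the memory, stream, and mem_lock types here are simplified stand-ins for the cldnn internals, not the real plugin API.

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <vector>

    // Simplified stand-ins for the cldnn memory/stream machinery (hypothetical).
    struct stream {};

    struct memory {
        std::vector<uint8_t> bytes;  // host-visible copy of the buffer contents
    };

    // Typed read-only view of a buffer, mirroring how
    // mem_lock<T, mem_lock_type::read> is used in the diff above.
    template <typename T>
    struct mem_lock {
        mem_lock(const std::shared_ptr<memory>& mem, stream&)
            : data(reinterpret_cast<const T*>(mem->bytes.data())) {}
        const T& operator[](size_t i) const { return data[i]; }
        const T* data;
    };

    int main() {
        stream strm;

        // memory_deps maps an input port to its captured buffer; ports 5 and 7
        // show up here because get_shape_infer_dependencies lists them.
        std::map<size_t, std::shared_ptr<memory>> memory_deps;

        auto is_prompt = std::make_shared<memory>();
        is_prompt->bytes = {1};  // uint8_t flag: prompt stage
        memory_deps[5] = is_prompt;

        auto max_ctx = std::make_shared<memory>();
        max_ctx->bytes.resize(sizeof(int32_t));
        int32_t len = 128;
        std::memcpy(max_ctx->bytes.data(), &len, sizeof(len));
        memory_deps[7] = max_ctx;

        // The same read pattern as in get_sdpa_kernel_params:
        mem_lock<uint8_t> is_prompt_lock(memory_deps.at(5), strm);
        bool is_prompt_stage = is_prompt_lock[0] != 0;

        int32_t max_context_len = 0;
        if (is_prompt_stage) {
            max_context_len = 128;  // stand-in for slot_mapping.get_shape()[1]
        } else {
            mem_lock<int32_t> max_ctx_lock(memory_deps.at(7), strm);
            max_context_len = max_ctx_lock[0];
        }
        std::cout << "max_context_len = " << max_context_len << "\n";
        return 0;
    }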
src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h

+1 −1
@@ -23,7 +23,7 @@ struct typed_program_node<paged_attention> : public typed_program_node_base<paged_attention>
     program_node& key_cache() const { return get_dependency(3); }
     program_node& value_cache() const { return get_dependency(4); }
 
-    std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
+    std::vector<size_t> get_shape_infer_dependencies() const override { return { 5 /* is_prompt */, 7 /* max_context_len */ }; }
 };
 
 using paged_attention_node = typed_program_node<paged_attention>;
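A note on this one-line change: the ports returned by get_shape_infer_dependencies are those whose contents (not just shapes) are needed during shape/param inference, so the plugin captures their memories into impl_param.memory_deps. That is what makes constant_mem.at(5) and constant_mem.at(7) available in the .cpp change above. The port assignments, as far as they can be read off this commit:

    port 3: key_cache        (key_cache() accessor)
    port 4: value_cache      (value_cache() accessor)
    port 5: is_prompt        (uint8_t scalar)
    port 6: slot_mapping     (shape[1] used as the slot count in the prompt stage)
    port 7: max_context_len  (int32_t scalar)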
