Commit b624f46

Fix the accuracy issue
1 parent 15d02d0 commit b624f46

File tree

7 files changed: +414 −110 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+67-16
@@ -158,6 +158,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            kernel_offset += _kernels_data[s].kernels.size();
        }
        for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) {
+           auto time0 = std::chrono::high_resolution_clock::now();
            if (_kernels_data[stage].kernels[kd_idx].skip_execution)
                continue;

@@ -166,14 +167,23 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            bool needs_completion_event = instance.needs_completion_event();

            auto& params = _kernels_data[stage].kernels[kd_idx].params;
+
+
            auto args = get_arguments(instance, stage);
            args.scalars = &params.scalars;

            for (const auto& m : instance.get_intermediates_memories()) {
                args.intermediates.push_back(m);
            }

+           // if (stage == Stage::SDPA && kd_idx != 0) {
+           //     auto& inputs = args.inputs;
+           //     inputs.erase(inputs.begin(), inputs.begin() + 7);
+           // }
+
+           auto time1 = std::chrono::high_resolution_clock::now();
            stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args);
+           auto time2 = std::chrono::high_resolution_clock::now();

            const auto& gws = params.workGroups.global;
            const auto& lws = params.workGroups.local;
@@ -183,30 +193,38 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                                   << (needs_completion_event ? " has_completion_event=true" : "") << std::endl;

            auto ev = stream.enqueue_kernel(*_kernels[idx_final], params, args, tmp_events, needs_completion_event);
+           auto time3 = std::chrono::high_resolution_clock::now();
            if (_kernels_data[stage].needs_sub_kernels_sync) {
                tmp_events = {ev};
            }
+
+           auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(time1 - time0).count();
+           auto time_res1 = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1).count();
+           auto time_res2 = std::chrono::duration_cast<std::chrono::microseconds>(time3 - time2).count();
+           GPU_DEBUG_TRACE_DETAIL << "Time execute_stage inside = " << time_res0 << " " << time_res1 << " " << time_res2 << "\n";
+
            all_events.push_back(ev);
        }


-       if (instance.get_network().get_config().get_property(ov::enable_profiling)) {
-           auto final_event = stream.group_events(all_events);
-           if (final_event != nullptr) {
-               stream.wait_for_events({final_event});
-               auto profiling_info = final_event->get_profiling_info();
-               for (const auto &interval : profiling_info) {
-                   if (interval.stage == cldnn::instrumentation::profiling_stage::executing) {
-                       auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(interval.value->value()).count();
-                       GPU_DEBUG_INFO << "PagedAttention " << stage << " stage time: " << time_res0 << " mcs\n";
-                   }
-               }
-           }
-       }
+       // if (instance.get_network().get_config().get_property(ov::enable_profiling)) {
+       //     auto final_event = stream.group_events(all_events);
+       //     if (final_event != nullptr) {
+       //         stream.wait_for_events({final_event});
+       //         auto profiling_info = final_event->get_profiling_info();
+       //         for (const auto &interval : profiling_info) {
+       //             if (interval.stage == cldnn::instrumentation::profiling_stage::executing) {
+       //                 auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(interval.value->value()).count();
+       //                 GPU_DEBUG_INFO << "PagedAttention " << stage << " stage time: " << time_res0 << " mcs\n";
+       //             }
+       //         }
+       //     }
+       // }
    }

    event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
        auto& stream = instance.get_network().get_stream();
+       auto time0 = std::chrono::high_resolution_clock::now();
        // auto& service_stream = instance.get_network().get_engine().get_service_stream();
        std::vector<event::ptr> res_events;

@@ -217,6 +235,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        // GPU_DEBUG_TRACE_DETAIL << instance.id() << " stage is " << (is_prefill_stage ? "prefill" : "tokens generating") << "\n";

        execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE);
+       auto time1 = std::chrono::high_resolution_clock::now();

        if (false) {
            // auto sliding_window_memory = instance.input_memory_ptr(12);
@@ -291,12 +310,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            for (auto& ev : res_events)
                all_events.push_back(ev);

-           auto impl_param = *instance.get_impl_params();
-           auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
-           (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
+           // const auto impl_params = *instance.get_impl_params();
+           // auto sdpa_kernel_params = get_sdpa_kernel_params(impl_params, impl_params.is_dynamic());
+           // (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);

            execute_stage(all_events, instance, res_events, Stage::SDPA);

+           auto time2 = std::chrono::high_resolution_clock::now();
+           auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(time1 - time0).count();
+           auto time_res1 = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1).count();
+           GPU_DEBUG_TRACE_DETAIL << "Time PA = " << time_res0 << " " << time_res1 << "\n";
+
            return aggregate_events(res_events, stream, res_events.size() > 1);
        }
    }
@@ -331,6 +355,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
            config.kv_heads_num = kv_heads_num;
            config.block_size = block_size;
            config.x_size = x_size;
+           config.max_context_len = 1;
        }

        return config;
@@ -397,6 +422,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        params.inputs[6] = convert_data_tensor(scale_layout);

        params.configuration = get_sdpa_configuration(impl_param);
+       GPU_DEBUG_TRACE_DETAIL << "Number of constant_mem " << impl_param.memory_deps.size() << ", dynamic=" << is_dynamic << "\n";
+       if (!is_dynamic) {
+           auto& constant_mem = impl_param.memory_deps;
+
+
+           const auto max_context_len_mem = constant_mem.at(7);
+           mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
+           GPU_DEBUG_TRACE_DETAIL << "max_context_len_mem_lock=" << max_context_len_mem_lock[0] << "\n";
+
+           const auto is_prompt_stage_mem = constant_mem.at(5);
+           mem_lock<uint8_t, mem_lock_type::read> is_prompt_stage_mem_lock(is_prompt_stage_mem, impl_param.get_stream());
+           bool is_prompt_stage = is_prompt_stage_mem_lock[0];
+
+           if (is_prompt_stage) {
+               // Use number of slots for KV cache as a maximum context length for the first iteration
+               auto slot_mapping = impl_param.get_input_layout(6);
+               params.configuration.max_context_len = slot_mapping.get_shape()[1];
+           } else {
+               const auto max_context_len_mem = constant_mem.at(7);
+               mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
+               params.configuration.max_context_len = max_context_len_mem_lock[0];
+           }
+       }

        const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
        const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
@@ -434,6 +482,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
        (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
+
+       auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
+       (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
    }
};
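
A note on the block added to get_sdpa_kernel_params() above: once shapes are static, it reads two of the primitive's memory dependencies on the host (the is_prompt flag at port 5 and max_context_len at port 7) and, on the prompt iteration, takes the slot-mapping width instead. Below is a minimal standalone sketch of that selection logic only; the types are hypothetical stand-ins, not the real cldnn mem_lock and layout API.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for a scalar input whose contents were already copied to the host.
struct HostScalarBuffer {
    std::vector<int32_t> data;
};

struct SdpaConfig {
    int64_t max_context_len = 1;  // default used while shapes are still dynamic
};

// Mirrors the added branch: on the prompt (first) iteration the slot-mapping
// width bounds the context; afterwards the value comes from the
// max_context_len input (port 7 in the diff above).
int64_t select_max_context_len(bool is_prompt_stage,
                               int64_t slot_mapping_width,
                               const HostScalarBuffer& max_context_len_input) {
    if (is_prompt_stage)
        return slot_mapping_width;
    return max_context_len_input.data.at(0);
}

int main() {
    SdpaConfig config;
    HostScalarBuffer max_context_len_input{{128}};
    config.max_context_len = select_max_context_len(/*is_prompt_stage=*/false,
                                                    /*slot_mapping_width=*/1024,
                                                    max_context_len_input);
    std::cout << "max_context_len = " << config.max_context_len << "\n";  // prints 128
    return 0;
}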

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h

+1-1
@@ -23,7 +23,7 @@ struct typed_program_node<paged_attention> : public typed_program_node_base<page
    program_node& key_cache() const { return get_dependency(3); }
    program_node& value_cache() const { return get_dependency(4); }

-   std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
+   std::vector<size_t> get_shape_infer_dependencies() const override { return { 5 /* is_prompt */, 7 /* max_context_len */ }; }
};

using paged_attention_node = typed_program_node<paged_attention>;
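
For context, get_shape_infer_dependencies() names the input ports whose contents (not just shapes) must be available on the host before the implementation can update its kernel parameters; returning { 5, 7 } is what lets the mem_lock reads in get_sdpa_kernel_params() find those buffers in impl_param.memory_deps. A rough, hypothetical model of how such a list is typically consumed (this is not the actual primitive_inst code):

#include <cstddef>
#include <map>
#include <memory>
#include <vector>

struct Memory {};  // stand-in for a device buffer handle (cldnn::memory)

struct Node {
    std::vector<std::shared_ptr<Memory>> dep_outputs;  // one buffer per input port
    std::vector<size_t> shape_infer_deps;               // e.g. {5, 7} after this change
};

// Collect the declared ports' buffers so the impl can lock and read them later.
std::map<size_t, std::shared_ptr<Memory>> collect_memory_deps(const Node& node) {
    std::map<size_t, std::shared_ptr<Memory>> memory_deps;
    for (size_t port : node.shape_infer_deps) {
        if (port < node.dep_outputs.size())
            memory_deps.emplace(port, node.dep_outputs[port]);
    }
    return memory_deps;
}

int main() {
    Node node;
    node.dep_outputs.resize(13);
    node.dep_outputs[5] = std::make_shared<Memory>();  // is_prompt
    node.dep_outputs[7] = std::make_shared<Memory>();  // max_context_len
    node.shape_infer_deps = {5, 7};
    return collect_memory_deps(node).size() == 2 ? 0 : 1;
}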

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+6
@@ -380,6 +380,11 @@ void primitive_inst::update_shape() {

            auto dep_mem = _network.get_output_memory(dep_id);
            memory_deps.insert({i, dep_mem});
+
+           // Ignore shape_infer dependency for input_layout dependency type and in_order queue
+           if (get_node().is_type<paged_attention>() && dep.is_type<input_layout>() && queue_type == QueueTypes::in_order)
+               continue;
+
            if (!get_node().is_type<shape_of>() && !dep.is_in_shape_of_subgraph()) {
                has_runtime_deps = true;

@@ -391,6 +396,7 @@
        }
    }

+   GPU_DEBUG_TRACE_DETAIL << id() << " runtime dependencies = " << has_runtime_deps << "\n";
    if (has_runtime_deps) {
        OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("update_shape_sync: " + id()));
        if (!dependencies_events.empty() && queue_type == QueueTypes::out_of_order) {
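
One plausible reading of the added early continue: with an in-order queue, commands complete in submission order, so a paged_attention shape-infer dependency that is a plain input_layout does not have to be tracked as a runtime dependency that forces a host-side wait in update_shape(). A toy model of that decision, with hypothetical names:

#include <iostream>

enum class QueueType { in_order, out_of_order };

// Decide whether reading a dependency's buffer on the host needs an explicit
// wait; this mirrors the intent of the new early continue, not the real code.
bool needs_host_wait_for_dep(QueueType queue, bool dep_is_plain_input) {
    if (dep_is_plain_input && queue == QueueType::in_order)
        return false;  // skipped: ordering is already guaranteed
    return true;       // otherwise treat it as a runtime dependency
}

int main() {
    std::cout << std::boolalpha
              << needs_host_wait_for_dep(QueueType::in_order, true) << "\n"      // false
              << needs_host_wait_for_dep(QueueType::out_of_order, true) << "\n"; // true
    return 0;
}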
