
Commit 4fa599e

PagedAttention tests
1 parent 8aa0991 commit 4fa599e

9 files changed: +230 -82 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+28)

@@ -78,6 +78,29 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     void set_arguments_impl(paged_attention_inst& instance) override {}

     kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage) const override {
+    {
+        kernel_arguments_data args;
+        args.shape_info = instance.shape_info_memory_ptr();
+        if (stage == Stage::KV_CACHE_UPDATE) {
+            args.inputs = { instance.input_memory_ptr(1), /* key */
+                            instance.input_memory_ptr(2), /* value */
+                            instance.input_memory_ptr(6) /* slot_mapping */};
+            args.outputs = { instance.input_memory_ptr(3), /* key_cache */
+                             instance.input_memory_ptr(4) /* value_cache */ };
+        } else if (stage == Stage::SDPA) {
+            args.inputs = { instance.input_memory_ptr(0), /* query */
+                            instance.input_memory_ptr(3), /* key_cache */
+                            instance.input_memory_ptr(4), /* value_cache */
+                            instance.input_memory_ptr(7), /* max_context_len */
+                            instance.input_memory_ptr(8), /* context_lens */
+                            instance.input_memory_ptr(9), /* block_tables */
+                            instance.input_memory_ptr(10) /* scale */ };
+            args.outputs = { instance.output_memory_ptr(0) };
+        }
+
+        return args;
+    }
+
     // WA due to lack of proper handling of key and value cache buffers. Keep them in impl for test purpose.
     if (value_cache_mem == nullptr) {
         const auto key_cache_layout = instance.get_impl_params()->get_input_layout(3);

@@ -201,6 +224,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     const int64_t heads_num = hidden_size / head_size;
     const int64_t num_queries_per_kv = heads_num / kv_heads_num;

+    std::cout << "Prefill stage: batch_size=" << batch_size << " seq_len=" << seq_len << " hidden_size=" << hidden_size
+              << " kv_heads_num=" << kv_heads_num << " heads_num=" << heads_num << " head_size=" << head_size
+              << " q=" << query_layout.to_short_string() << " k_cache=" << key_cache_layout.to_short_string()
+              << " v_cache=" << value_cache_layout.to_short_string() << "\n";
+
     auto attention_bias = generate_attention_bias(batch_size, seq_len, sliding_window, instance.get_network().get_engine());

     auto query_mem = instance.input_memory_ptr(0);
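
Note: the numeric input indices used in get_arguments() above are only documented by the inline comments. The mapping they imply can be summarized as below; this is a hypothetical helper inferred from those comments, not part of the commit or of the primitive's official definition.

    #include <cstddef>

    // Hypothetical names for the PagedAttention input indices referenced above,
    // inferred from the inline /* ... */ comments; the real operand order may differ.
    enum PagedAttentionInputIdx : std::size_t {
        QUERY           = 0,
        KEY             = 1,
        VALUE           = 2,
        KEY_CACHE       = 3,
        VALUE_CACHE     = 4,
        // index 5 is not referenced in this hunk
        SLOT_MAPPING    = 6,
        MAX_CONTEXT_LEN = 7,
        CONTEXT_LENS    = 8,
        BLOCK_TABLES    = 9,
        SCALE           = 10,
    };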

src/plugins/intel_gpu/src/graph/layout_optimizer.cpp (+1)

@@ -1690,6 +1690,7 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
         }
         // TODO: uncomment this code when onednn gemm implementations will have real perf improvements vs cldnn
     } else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
+        return impl_types::ocl;
         if (!_optimization_attributes.use_onednn_impls)
             return impl_types::ocl;

src/plugins/intel_gpu/src/graph/network.cpp (+1 -1)

@@ -1054,7 +1054,7 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
         auto prog_id = ((get_program() != nullptr) ? get_program()->get_id() : 0);
         auto net_id = get_id();
         GPU_DEBUG_IF(debug_config->is_target_iteration(curr_iter) &&
-                     debug_config->is_layer_for_dumping(layer_name, inst->is_output(), inst->is_input()) && prog_id == 2) {
+                     debug_config->is_layer_for_dumping(layer_name, inst->is_output(), inst->is_input())) {
             std::string debug_str_for_bin_load = " Command for loading : OV_GPU_LoadDumpRawBinary=\""
                                                  + layer_name + ":";
             for (size_t i = 0; i < get_primitive(layer_name)->outputs_memory_count(); i++) {

src/plugins/intel_gpu/src/graph/primitive_inst.cpp (+36 -2)

@@ -1244,6 +1244,39 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
         GPU_DEBUG_TRACE_DETAIL << "- inputs[" << i << "] : " << _deps[i].first->id() << std::endl;
     }
     GPU_DEBUG_TRACE_DETAIL << "-----------------------------------------------------------------" << std::endl;
+
+    std::vector<std::string> print_ids = {"pagedattentionextension:PagedAttentionExtension_606",
+                                          "gemm:MatMul_112999",
+                                          "softmax:Softmax_113002",
+                                          "gemm:__module.model.layers.0.self_attn/aten::transpose/Transpose_3",
+                                          /* BATCHED chatglm3 fp32 */
+                                          "matmul:MatMul_113004",
+                                          "add:Add_113006",
+                                          "softmax:Softmax_113007",
+                                          "matmul:__module.model.layers.0.self_attn/aten::scaled_dot_product_attention/ScaledDotProductAttention",
+                                          "transpose:__module.model.layers.0.self_attn/aten::transpose/Transpose_3",
+                                          /* Batched open_llama-7b fp32 + INT8 */
+                                          "matmul:MatMul_158917",
+                                          "add:Add_158919",
+                                          "softmax:Softmax_158920",
+                                          "matmul:__module.model.layers.0.self_attn/aten::scaled_dot_product_attention/ScaledDotProductAttention",
+                                          /* open llama FP32_INT4 */
+                                          };
+
+    if (_impl_params->desc->type_string() == "paged_attention" ||
+        _impl_params->desc->type_string() == "softmax" ||
+        _impl_params->desc->type_string() == "gemm" ||
+        _impl_params->desc->type_string() == "eltwise" ||
+        _impl_params->desc->type_string() == "add" ||
+        _impl_params->desc->type_string() == "transpose")
+        print_ids.push_back(id());
+
+    if (std::find(print_ids.begin(), print_ids.end(), id()) != print_ids.end() && get_network().get_config().get_property(ov::enable_profiling)) {
+        GPU_DEBUG_INFO << "Execute " << id() << " (type: " << _impl_params->desc->type_string() << ") " << std::endl;
+        for (size_t i = 0; i < _deps.size(); ++i) {
+            GPU_DEBUG_INFO << "- inputs[" << i << "] : " << _deps[i].first->id() << " - " << _deps[i].first->get_output_layout(0).to_short_string() << std::endl;
+        }
+    }
     bool need_args_update = false;
     _mem_changed = false;
     const auto orig_outputs = _outputs;

@@ -1400,14 +1433,15 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
         GPU_DEBUG_PROFILED_STAGE(instrumentation::pipeline_stage::inference);
         auto ev = _impl->execute(dependencies, *this);

-        GPU_DEBUG_IF(!debug_config->dump_profiling_data.empty()) {
+        if (std::find(print_ids.begin(), print_ids.end(), id()) != print_ids.end() && get_network().get_config().get_property(ov::enable_profiling)) {
            get_network().get_stream().wait_for_events({ev});

            if (ev != nullptr) {
                auto profiling_info = ev->get_profiling_info();
                for (const auto &interval : profiling_info) {
                    if (interval.stage == cldnn::instrumentation::profiling_stage::executing) {
-                        GPU_DEBUG_CODE(stage_prof.set_custom_stage_duration(interval.value->value()));
+                        auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(interval.value->value()).count();
+                        GPU_DEBUG_INFO << id() << " performace time = " << time_res0 << " mcs\n";
                    }
                }
            }
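
The replacement at new line 1443 converts the executing-stage duration to whole microseconds via std::chrono::duration_cast before logging it. A minimal standalone sketch of that conversion, using a made-up nanosecond value instead of the cldnn profiling event:

    #include <chrono>
    #include <iostream>

    int main() {
        using namespace std::chrono;
        // Stand-in for interval.value->value(): an executing-stage duration from a profiling event.
        nanoseconds executing_time{1'234'567};
        // Same pattern as the diff: truncate to whole microseconds for the log line.
        auto time_us = duration_cast<microseconds>(executing_time).count();
        std::cout << "primitive performance time = " << time_us << " mcs\n";  // prints 1234
        return 0;
    }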

src/plugins/intel_gpu/src/graph/program.cpp (+3)

@@ -700,6 +700,9 @@ void program::transfer_memory_to_device() {
             auto& mem = data_node.get_attached_memory();
             auto mem_layout = mem.get_layout();
             auto alloc_type = mem.get_allocation_type();
+            if (ov::shape_size(mem_layout.get_shape()) == 0)
+                continue;
+            GPU_DEBUG_TRACE_DETAIL << "mem_layout: " << mem_layout.to_short_string() << " data: " << data_node_layout.to_short_string() << "\n";
             if (!mem_layout.compatible(data_node_layout)) {
                 std::string err_str("Node and memory layouts are incompatible, error occurred for " + node->id() + " node");
                 throw std::invalid_argument(err_str);
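
For context, ov::shape_size returns the product of a shape's dimensions, so any zero-sized dimension yields an element count of zero and the added guard skips transferring such a constant. A minimal usage sketch with arbitrary shapes:

    #include <iostream>
    #include <openvino/core/shape.hpp>

    int main() {
        ov::Shape empty_shape{0, 4, 64};   // a zero dimension -> no elements to transfer
        ov::Shape normal_shape{2, 4, 64};
        std::cout << ov::shape_size(empty_shape) << "\n";   // 0: such data nodes are skipped above
        std::cout << ov::shape_size(normal_shape) << "\n";  // 512
        return 0;
    }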

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl (+9)

@@ -36,6 +36,10 @@ KERNEL(pa_kv_cache_update)(
     // printf("Update value %d. %d (%f)\n", out_offset, in_offset, value_data[in_offset]);
     // }

+    // if (batch_idx == 0 && hidden_idx == 0) {
+    //     printf("Update value slot for %d = %d\n", seq_idx, slot_idx);
+    // }
+
     value_cache_data[out_offset] = value_data[in_offset];
 #else
     const uint head_size_outer_block = hidden_idx / X_BLOCK_SIZE;

@@ -49,6 +53,11 @@ KERNEL(pa_kv_cache_update)(
     // printf("Update key_cache %d. %d (%f); seq_idx=%d, hidden_idx=%d, slot_idx=%d, block_index=%d, block_offset=%d; block_elem_num=%d\n", out_offset, in_offset, key_data[in_offset],
     //     seq_idx, hidden_idx, slot_idx, block_index, block_offset, block_elem_num);
     // }
+
+    // if (batch_idx == 0 && hidden_idx == 0) {
+    //     printf("Update key slot for %d = %d\n", seq_idx, slot_idx);
+    // }
+
     key_cache_data[out_offset] = key_data[in_offset];
 #endif
 }
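
The commented-out printfs above refer to the slot/block bookkeeping of the paged KV cache (slot_idx, block_index, block_offset, block_elem_num). As a rough illustration of the usual addressing scheme, with assumed constants and simplified offset math rather than this kernel's exact formulas, a flat slot index from slot_mapping splits into a block index and an in-block offset:

    #include <cstdint>
    #include <cstdio>

    // Assumed values for illustration only; the real kernel gets these from JIT constants.
    constexpr uint32_t BLOCK_SIZE = 16;        // tokens stored per cache block
    constexpr uint32_t BLOCK_ELEM_NUM = 1024;  // elements per block (per head), placeholder

    int main() {
        uint32_t slot_idx = 37;                         // e.g. slot_mapping[seq_idx]
        uint32_t block_index  = slot_idx / BLOCK_SIZE;  // which physical cache block
        uint32_t block_offset = slot_idx % BLOCK_SIZE;  // token position inside that block
        // A per-element destination offset would then combine these with the hidden/head
        // index, e.g. block_index * BLOCK_ELEM_NUM + hidden_idx * BLOCK_SIZE + block_offset.
        std::printf("slot=%u -> block_index=%u, block_offset=%u\n", slot_idx, block_index, block_offset);
        return 0;
    }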
