Skip to content

Commit 8aa0991

Browse files
committed
WIP: [GPU] PagedAttention initial impl
1 parent 757caa1 commit 8aa0991

27 files changed

+1563
-15
lines changed

src/plugins/intel_gpu/include/intel_gpu/plugin/program_builder.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ class ProgramBuilder final {
172172
void CreateSingleLayerPrimitive(cldnn::topology& topology, const std::shared_ptr<ov::Node>& op);
173173
};
174174

175+
void CreatePagedAttention(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op);
175176
void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& node, CustomLayerPtr customLayer);
176177
void CreateUnaryEltwiseOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& node,
177178
cldnn::activation_func func, cldnn::activation_additional_params params);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
#include "primitive.hpp"
7+
#include "intel_gpu/graph/program.hpp"
8+
9+
#include <vector>
10+
11+
namespace cldnn {
12+
13+
/// @brief Primitive descriptor for the PagedAttention operation.
/// @details Adds no serialized or compared state of its own: equality, save and load
/// all defer to the common/base implementations.
struct paged_attention : public primitive_base<paged_attention> {
    CLDNN_DECLARE_PRIMITIVE(paged_attention)

    /// @brief Default constructor: empty primitive id and an empty input list.
    paged_attention() : primitive_base("", {}) {}

    /// @brief Constructs the primitive from its id and inputs.
    /// @param id Identifier of this primitive.
    /// @param inputs Input descriptors; exactly 13 entries are required, otherwise the assertion below fires.
    /// @param output_padding Padding passed to the base class as the single output padding entry.
    paged_attention(const primitive_id& id,
                    const std::vector<input_info>& inputs,
                    const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding}) {
        // The assertion statement is kept verbatim: the macro may embed the condition text in its message.
        OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
    }

    /// @brief Equality is decided solely by the common primitive parameters.
    /// @note `prefill_stage` does not take part in the comparison.
    bool operator==(const primitive& rhs) const override {
        const bool same_common_params = compare_common_params(rhs);
        return same_common_params;
    }

    /// @brief Writes only the base-class state to @p ob; `prefill_stage` is not written.
    void save(BinaryOutputBuffer& ob) const override {
        // Qualified call through the injected class name, i.e. primitive_base<paged_attention>::save.
        primitive_base::save(ob);
    }

    /// @brief Reads only the base-class state from @p ib; `prefill_stage` is left untouched.
    void load(BinaryInputBuffer& ib) override {
        // Qualified call through the injected class name, i.e. primitive_base<paged_attention>::load.
        primitive_base::load(ib);
    }

    /// Shared handle to a cldnn::program; default-initialized to null and not set anywhere in this struct.
    /// NOTE(review): presumably the program executed for the prefill stage - confirm against the code that assigns it.
    std::shared_ptr<cldnn::program> prefill_stage;
};
39+
} // namespace cldnn

src/plugins/intel_gpu/src/graph/graph_optimizer/compile_graph.cpp

+10-12
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ void compile_graph::run(program& p) {
3434
}
3535
}
3636

37-
auto task_executor = p.get_task_executor();
3837
auto& proc_order = p.get_processing_order();
39-
std::vector<ov::threading::Task> tasks;
40-
std::exception_ptr exception;
4138
for (size_t idx = 0; idx < proc_order.size(); idx++) {
4239
auto& node = *(std::next(proc_order.begin(), idx));
4340
const bool use_shape_agnostic_impl = !p.get_config().get_property(ov::intel_gpu::use_only_static_kernels_for_dynamic_shape);
@@ -79,29 +76,30 @@ void compile_graph::run(program& p) {
7976
can_select_impl = true;
8077

8178
if (can_select_impl) {
82-
tasks.push_back([node, &exception, change_initial_impl, original_impl_type] {
8379
try {
80+
std::exception_ptr curr_excp;
8481
node->selected_impl = node->type()->choose_impl(*node);
8582
if (change_initial_impl) {
8683
GPU_DEBUG_TRACE_DETAIL << node->id() << ": use " << node->get_preferred_impl_type()
8784
<< " as initial impl instead of " << original_impl_type << std::endl;
8885
node->set_preferred_impl_type(original_impl_type);
8986
}
9087
} catch(...) {
91-
exception = std::current_exception();
88+
try {
89+
std::exception_ptr curr_excp;
90+
if (curr_excp = std::current_exception())
91+
{
92+
std::rethrow_exception(curr_excp);
93+
}
94+
} catch (const std::exception& e) {
95+
std::cerr << "Can't compile " << node->id() << ", error " << e.what() << "\n";
96+
}
9297
}
93-
});
9498
} else {
9599
if (change_initial_impl) {
96100
node->set_preferred_impl_type(original_impl_type);
97101
}
98102
}
99103
}
100104

101-
task_executor->run_and_wait(tasks);
102-
tasks.clear();
103-
104-
if (exception) {
105-
std::rethrow_exception(exception);
106-
}
107105
}

src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ void mark_runtime_skippable_nodes::run(program& p) {
4040
}
4141
});
4242
program_helpers::do_for_types<permute>(*node, [](permute_node& node){
43+
return;
4344
auto impl_params = node.get_kernel_impl_params();
4445
if (node.is_output() ||
4546
node.has_fused_primitives() ||

0 commit comments

Comments
 (0)