Skip to content

Commit a767301

Browse files
committed
[GPU] Extend shape_of subgraphs markup logic to include PagedAttention input
1 parent f671bfc commit a767301

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/mark_shape_of_subgraphs.cpp

+22-2
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,34 @@
1010
#include "select_inst.h"
1111
#include "strided_slice_inst.h"
1212
#include "gather_inst.h"
13+
#include "input_layout_inst.h"
14+
#include "paged_attention_inst.h"
1315
#include "pass_manager.h"
1416

1517
#include "intel_gpu/graph/program.hpp"
1618

1719
using namespace cldnn;
1820

19-
void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
21+
static bool is_shape_of_subgraph_root(program_node& node) {
2022
if (node.is_type<shape_of>()) {
23+
return true;
24+
}
25+
26+
if (node.is_type<input_layout>()) {
27+
const auto& users = node.get_users();
28+
for (const auto& user : users) {
29+
const auto max_context_len_input_id = 12;
30+
if (user->is_type<paged_attention>() && user->get_dependency_index(node) == max_context_len_input_id) {
31+
return true;
32+
}
33+
}
34+
}
35+
36+
return false;
37+
}
38+
39+
void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
40+
if (is_shape_of_subgraph_root(node)) {
2141
mark_node(node);
2242
return;
2343
}
@@ -102,7 +122,7 @@ void mark_shape_of_subgraphs::mark_node(program_node& node) {
102122

103123
// If current node has shape_of type add it to dependant shape_of nodes for
104124
// correct dependency propagation for users
105-
if (node.is_type<shape_of>())
125+
if (is_shape_of_subgraph_root(node))
106126
node.add_dependant_shape_of_node(&node);
107127

108128
// Add parent shape_of nodes from other dependencies if there are any

src/plugins/intel_gpu/src/graph/program_node.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ void program_node::select_preferred_formats(impl_types impl_type) {
658658
}
659659

660660
void program_node::add_dependant_shape_of_node(const program_node* node) {
661-
OPENVINO_ASSERT(node->is_type<shape_of>(), "[GPU] Expected node type is shape_of");
661+
OPENVINO_ASSERT(node->is_type<shape_of>() || node->is_type<input_layout>(), "[GPU] Expected node type is shape_of");
662662
dependant_shape_of_nodes.insert(node);
663663
}
664664

0 commit comments

Comments
 (0)