File tree 2 files changed +23
-3
lines changed
src/plugins/intel_gpu/src/graph
2 files changed +23
-3
lines changed Original file line number Diff line number Diff line change 10
10
#include " select_inst.h"
11
11
#include " strided_slice_inst.h"
12
12
#include " gather_inst.h"
13
+ #include " input_layout_inst.h"
14
+ #include " paged_attention_inst.h"
13
15
#include " pass_manager.h"
14
16
15
17
#include " intel_gpu/graph/program.hpp"
16
18
17
19
using namespace cldnn ;
18
20
19
- void mark_shape_of_subgraphs::look_for_shape_of_subgraph (program_node& node) {
21
+ static bool is_shape_of_subgraph_root (program_node& node) {
20
22
if (node.is_type <shape_of>()) {
23
+ return true ;
24
+ }
25
+
26
+ if (node.is_type <input_layout>()) {
27
+ const auto & users = node.get_users ();
28
+ for (const auto & user : users) {
29
+ const auto max_context_len_input_id = 12 ;
30
+ if (user->is_type <paged_attention>() && user->get_dependency_index (node) == max_context_len_input_id) {
31
+ return true ;
32
+ }
33
+ }
34
+ }
35
+
36
+ return false ;
37
+ }
38
+
39
+ void mark_shape_of_subgraphs::look_for_shape_of_subgraph (program_node& node) {
40
+ if (is_shape_of_subgraph_root (node)) {
21
41
mark_node (node);
22
42
return ;
23
43
}
@@ -102,7 +122,7 @@ void mark_shape_of_subgraphs::mark_node(program_node& node) {
102
122
103
123
// If current node has shape_of type add it to dependant shape_of nodes for
104
124
// correct dependency propagation for users
105
- if (node. is_type <shape_of>( ))
125
+ if (is_shape_of_subgraph_root (node ))
106
126
node.add_dependant_shape_of_node (&node);
107
127
108
128
// Add parent shape_of nodes from other dependencies if there are any
Original file line number Diff line number Diff line change @@ -658,7 +658,7 @@ void program_node::select_preferred_formats(impl_types impl_type) {
658
658
}
659
659
660
660
void program_node::add_dependant_shape_of_node (const program_node* node) {
661
- OPENVINO_ASSERT (node->is_type <shape_of>(), " [GPU] Expected node type is shape_of" );
661
+ OPENVINO_ASSERT (node->is_type <shape_of>() || node-> is_type <input_layout>() , " [GPU] Expected node type is shape_of" );
662
662
dependant_shape_of_nodes.insert (node);
663
663
}
664
664
You can’t perform that action at this time.
0 commit comments