Skip to content

Commit 2140d77

Browse files
committed
[GPU] Fix shape infer optimziation of shape_of subgraphs in case of data input usage
1 parent 7faadea commit 2140d77

File tree

2 files changed

+183
-2
lines changed

2 files changed

+183
-2
lines changed

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ void primitive_inst::update_shape() {
368368
if (_node->is_in_shape_of_subgraph()) {
369369
bool subgraph_input_changed = false;
370370
for (size_t i = 0; i < dependant_shape_of_insts.size(); i++) {
371-
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED)) {
371+
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED) || dependant_shape_of_insts[i]->_node->is_type<input_layout>()) {
372372
subgraph_input_changed = true;
373373
break;
374374
}
@@ -396,6 +396,7 @@ void primitive_inst::update_shape() {
396396
const auto& insts = _deps[i].first->dependant_shape_of_insts;
397397
for (auto& inst : insts) {
398398
can_skip &= !inst->get_flag(ExecutionFlags::SHAPE_CHANGED);
399+
can_skip &= !inst->_node->is_type<input_layout>();
399400
}
400401
if (can_skip)
401402
continue;
@@ -1849,7 +1850,7 @@ void primitive_inst::prepare_primitive() {
18491850
if (_node->is_in_shape_of_subgraph() && dependant_shape_of_insts.front()->is_dynamic()) {
18501851
bool subgraph_input_changed = false;
18511852
for (size_t i = 0; i < dependant_shape_of_insts.size(); i++) {
1852-
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED)) {
1853+
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED) || dependant_shape_of_insts[i]->_node->is_type<input_layout>()) {
18531854
subgraph_input_changed = true;
18541855
break;
18551856
}

src/plugins/intel_gpu/tests/unit/dynamic_execution/update_shape_test.cpp

+180
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
//
44

55
#include "test_utils.h"
6+
#include "random_generator.hpp"
67

78
#include <intel_gpu/primitives/input_layout.hpp>
89
#include <intel_gpu/primitives/data.hpp>
910
#include <intel_gpu/primitives/shape_of.hpp>
1011
#include <intel_gpu/primitives/broadcast.hpp>
1112
#include <intel_gpu/primitives/gather.hpp>
1213
#include <intel_gpu/primitives/non_zero.hpp>
14+
#include <intel_gpu/primitives/paged_attention.hpp>
15+
#include <intel_gpu/primitives/gather.hpp>
1316

1417
#include "program_wrapper.h"
1518

@@ -65,4 +68,181 @@ TEST(update_shape_test, ocl_impl_in_shapeof_subgraph) {
6568
std::map<primitive_id, network_output> outputs;
6669
OV_ASSERT_NO_THROW(outputs = network.execute());
6770
}
71+
72+
TEST(update_shape_test, max_context_len_shapeof_subgraph) {
73+
tests::random_generator rg(GET_SUITE_NAME);
74+
auto& engine = get_test_engine();
75+
76+
auto input_data_layout = layout{ov::PartialShape{1, -1}, data_types::f16, format::bfyx};
77+
78+
auto qkv_mem_layout = layout{ov::PartialShape{1, 128}, data_types::f16, format::bfyx};
79+
auto qkv_mem = engine.allocate_memory(qkv_mem_layout);
80+
auto qkv_rnd = rg.generate_random_1d<ov::float16>(qkv_mem_layout.count(), 0, 10);
81+
set_values(qkv_mem, qkv_rnd);
82+
83+
auto key_cache_mem_layout = layout{ov::PartialShape{1, 2, 64, 16}, data_types::f16, format::bfyx};
84+
auto value_cache_mem_layout = layout{ov::PartialShape{1, 2, 16, 64}, data_types::f16, format::bfyx};
85+
auto key_cache_mem = engine.allocate_memory(key_cache_mem_layout);
86+
auto value_cache_mem = engine.allocate_memory(value_cache_mem_layout);
87+
auto cache_rnd = rg.generate_random_1d<ov::float16>(key_cache_mem_layout.count(), 0, 10);
88+
set_values(key_cache_mem, cache_rnd);
89+
set_values(value_cache_mem, cache_rnd);
90+
91+
auto past_lens_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
92+
auto past_lens_mem = engine.allocate_memory(past_lens_mem_layout);
93+
set_values(value_cache_mem, {8});
94+
95+
auto subsequence_begins_mem_layout = layout{ov::PartialShape{2}, data_types::i32, format::bfyx};
96+
auto subsequence_begins_mem = engine.allocate_memory(subsequence_begins_mem_layout);
97+
set_values(subsequence_begins_mem, {0, 1});
98+
99+
auto block_indices_mem_layout = layout{ov::PartialShape{2}, data_types::i32, format::bfyx};
100+
auto block_indices_mem = engine.allocate_memory(block_indices_mem_layout);
101+
set_values(block_indices_mem, {0});
102+
103+
auto block_indices_begins_mem_layout = layout{ov::PartialShape{2}, data_types::i32, format::bfyx};
104+
auto block_indices_begins_mem = engine.allocate_memory(block_indices_begins_mem_layout);
105+
set_values(block_indices_begins_mem, {0, 1});
106+
107+
auto scale_mem_layout = layout{ov::PartialShape{1}, data_types::f16, format::bfyx};
108+
auto scale_mem = engine.allocate_memory(scale_mem_layout);
109+
set_values<ov::float16>(scale_mem, {1});
110+
111+
auto sliding_window_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
112+
auto sliding_window_mem = engine.allocate_memory(sliding_window_mem_layout);
113+
set_values(sliding_window_mem, {0});
114+
115+
auto alibi_mem_layout = layout{ov::PartialShape{0}, data_types::f16, format::bfyx};
116+
auto alibi_mem = engine.allocate_memory(alibi_mem_layout);
117+
118+
auto const_one_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
119+
auto const_one_mem = engine.allocate_memory(const_one_layout);
120+
set_values(const_one_mem, {1});
121+
122+
auto input_data_mem_layout = layout{ov::PartialShape{1, 9}, data_types::f16, format::bfyx};
123+
auto input_data_mem = engine.allocate_memory(input_data_mem_layout);
124+
auto input_data_rnd = rg.generate_random_1d<ov::float16>(input_data_mem_layout.count(), 0, 10);
125+
set_values(input_data_mem, input_data_rnd);
126+
127+
auto query_layout = layout{ov::PartialShape{-1, 128}, data_types::f16, format::bfyx};
128+
auto key_layout = query_layout;
129+
auto value_layout = query_layout;
130+
auto key_cache_layout = layout{ov::PartialShape{-1, 2, 64, 16}, data_types::f16, format::bfyx};
131+
auto dynamic_i32_layout = layout{ov::PartialShape::dynamic(1), data_types::i32, format::bfyx};
132+
auto value_cache_layout = key_cache_layout;
133+
auto past_lens_layout = dynamic_i32_layout;
134+
auto subsequence_begins_layout = dynamic_i32_layout;
135+
auto block_indices_layout = dynamic_i32_layout;
136+
auto block_indices_begins_layout = dynamic_i32_layout;
137+
auto scale_layout = layout{ov::PartialShape{1}, data_types::f16, format::bfyx};
138+
auto sliding_window_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
139+
auto alibi_layout = layout{ov::PartialShape{0}, data_types::f16, format::bfyx};
140+
auto max_context_len_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
141+
142+
std::vector<input_info> pa_inputs = {input_info("query"),
143+
input_info("key"),
144+
input_info("value"),
145+
input_info("key_cache"),
146+
input_info("value_cache"),
147+
input_info("past_lens"),
148+
input_info("subsequence_begins"),
149+
input_info("block_indices"),
150+
input_info("block_indices_begins"),
151+
input_info("scale"),
152+
input_info("sliding_window"),
153+
input_info("alibi"),
154+
input_info("max_context_len")};
155+
156+
auto pa_prim = paged_attention("paged_attention", pa_inputs);
157+
pa_prim.head_size = 64;
158+
pa_prim.kv_heads_num = 2;
159+
pa_prim.heads_num = 2;
160+
pa_prim.scale_val = 1.f;
161+
pa_prim.has_alibi = false;
162+
pa_prim.num_outputs = 1;
163+
pa_prim.has_rotated_blocks = false;
164+
165+
topology topology;
166+
topology.add(input_layout("input_data", input_data_layout));
167+
topology.add(input_layout("query", query_layout));
168+
topology.add(input_layout("key", key_layout));
169+
topology.add(input_layout("value", value_layout));
170+
topology.add(input_layout("key_cache", key_cache_layout));
171+
topology.add(input_layout("value_cache", value_cache_layout));
172+
topology.add(input_layout("past_lens", past_lens_layout));
173+
topology.add(input_layout("subsequence_begins", subsequence_begins_layout));
174+
topology.add(input_layout("block_indices", block_indices_layout));
175+
topology.add(input_layout("block_indices_begins", block_indices_begins_layout));
176+
topology.add(input_layout("scale", scale_layout));
177+
topology.add(input_layout("sliding_window", sliding_window_layout));
178+
topology.add(input_layout("alibi", alibi_layout));
179+
topology.add(input_layout("max_context_len", max_context_len_layout));
180+
topology.add(data("const_one", const_one_mem));
181+
topology.add(shape_of("shape_of", input_info("input_data"), data_types::i32));
182+
topology.add(gather("gather", input_info("shape_of"), input_info("const_one"), 0, 1, ov::Shape{}));
183+
topology.add(broadcast("broadcast", input_info("gather"), input_info("max_context_len"), {}, ov::op::BroadcastType::BIDIRECTIONAL));
184+
topology.add(pa_prim);
185+
186+
ExecutionConfig config = get_test_default_config(engine);
187+
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
188+
189+
network network(engine, topology, config);
190+
191+
network.set_input_data("input_data", input_data_mem);
192+
network.set_input_data("query", qkv_mem);
193+
network.set_input_data("key", qkv_mem);
194+
network.set_input_data("value", qkv_mem);
195+
network.set_input_data("key_cache", key_cache_mem);
196+
network.set_input_data("value_cache", value_cache_mem);
197+
network.set_input_data("past_lens", past_lens_mem);
198+
network.set_input_data("subsequence_begins", subsequence_begins_mem);
199+
network.set_input_data("block_indices", block_indices_mem);
200+
network.set_input_data("block_indices_begins", block_indices_begins_mem);
201+
network.set_input_data("scale", scale_mem);
202+
network.set_input_data("sliding_window", sliding_window_mem);
203+
network.set_input_data("alibi", alibi_mem);
204+
205+
// Set original max_context_len value
206+
auto max_context_len_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
207+
auto max_context_len_mem = engine.allocate_memory(max_context_len_mem_layout);
208+
set_values(max_context_len_mem, {9});
209+
210+
network.set_input_data("max_context_len", max_context_len_mem);
211+
212+
// 1st network execution
213+
network.execute();
214+
215+
auto broadcast_inst = network.get_primitive("broadcast");
216+
ASSERT_EQ(broadcast_inst->get_node().get_dependant_shape_of_nodes().size(), 2);
217+
218+
// Verify broadcast shape after first execution
219+
auto broadcast_shape = broadcast_inst->get_impl_params()->get_output_layout().get_shape();
220+
ASSERT_EQ(broadcast_shape, ov::Shape{9});
221+
222+
network.set_input_data("input_data", input_data_mem);
223+
network.set_input_data("query", qkv_mem);
224+
network.set_input_data("key", qkv_mem);
225+
network.set_input_data("value", qkv_mem);
226+
network.set_input_data("key_cache", key_cache_mem);
227+
network.set_input_data("value_cache", value_cache_mem);
228+
network.set_input_data("past_lens", past_lens_mem);
229+
network.set_input_data("subsequence_begins", subsequence_begins_mem);
230+
network.set_input_data("block_indices", block_indices_mem);
231+
network.set_input_data("block_indices_begins", block_indices_begins_mem);
232+
network.set_input_data("scale", scale_mem);
233+
network.set_input_data("sliding_window", sliding_window_mem);
234+
network.set_input_data("alibi", alibi_mem);
235+
236+
// Update max_context_len value, which should be taken into account in shape recalculation for broadcast
237+
set_values(max_context_len_mem, {8});
238+
239+
network.set_input_data("max_context_len", max_context_len_mem);
240+
241+
// 2nd network execution with updated max_context_len
242+
network.execute();
243+
244+
// Check if broadcast shape was recalculated
245+
broadcast_shape = broadcast_inst->get_impl_params()->get_output_layout().get_shape();
246+
ASSERT_EQ(broadcast_shape, ov::Shape{8});
247+
}
68248
} // update_shape_test

0 commit comments

Comments
 (0)