Skip to content

Commit c710343

Browse files
[TRANSFORMATIONS] Fix implicit conversion of ov::Node with multiple outputs to ov::Output (#29780)
### Details: In the IndirectSDPAOpt transformation, 1st input to the KVCache node may appear to be a VariadicSplit node with multiple outputs. Using the default output of the node is incorrect and leads to an exception if used as an input to another node because of an implicit conversion to ov::Output. Use the output of the VariadicSplit node explicitly to avoid the exception. ### Tickets: - [CVS-163937](https://jira.devtools.intel.com/browse/CVS-163937) Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
1 parent 9b29ce3 commit c710343

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ IndirectSDPAOpt::IndirectSDPAOpt() {
160160
ov::replace_node(gather_node_1, gather_input_node_1);
161161

162162
auto indirect_kv_cache_0 = std::make_shared<op::KVCache>(gather_input_node_0,
163-
kv_cache_node_0->get_input_node_shared_ptr(1),
163+
kv_cache_node_0->input_value(1),
164164
beam_idx_node,
165165
kv_cache_node_0->get_variable(),
166166
kv_cache_node_0->get_concat_axis(),
167167
gather_axis_0,
168168
kv_cache_node_0->get_output_element_type(0));
169169

170170
auto indirect_kv_cache_1 = std::make_shared<op::KVCache>(gather_input_node_1,
171-
kv_cache_node_1->get_input_node_shared_ptr(1),
171+
kv_cache_node_1->input_value(1),
172172
beam_idx_node,
173173
kv_cache_node_1->get_variable(),
174174
kv_cache_node_1->get_concat_axis(),

src/plugins/intel_gpu/tests/unit/transformations/indirect_kv_cache_test.cpp

+64
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "openvino/op/parameter.hpp"
1616
#include "openvino/op/result.hpp"
1717
#include "openvino/op/gather.hpp"
18+
#include "openvino/op/variadic_split.hpp"
1819
#include "openvino/pass/manager.hpp"
1920

2021
#include <transformations/utils/utils.hpp>
@@ -182,3 +183,66 @@ TEST_F(TransformationTestsF, IndirectKVCache4) {
182183
comparator.enable(FunctionsComparator::ATTRIBUTES);
183184
}
184185
}
186+
187+
TEST_F(TransformationTestsF, IndirectKVCache5) {
188+
std::vector<int64_t> in0_order = {0, 1, 2, 3};
189+
std::vector<int64_t> in1_order = {0, 1, 2, 3};
190+
std::vector<int64_t> in2_order = {0, 1, 2, 3};
191+
std::vector<int64_t> out_order = {0, 1, 2, 3};
192+
const bool is_causal = false;
193+
{
194+
auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1});
195+
auto key_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1, 32, -1, 80}, ov::element::f16, "v0"});
196+
auto value_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1, 32, -1, 80}, ov::element::f16, "v1"});
197+
auto key_past = std::make_shared<ov::intel_gpu::op::ReadValue>(key_variable);
198+
auto value_past = std::make_shared<ov::intel_gpu::op::ReadValue>(value_variable);
199+
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 0);
200+
auto key_gather_past = std::make_shared<ov::op::v8::Gather>(key_past, beam_idx, axis);
201+
auto value_gather_past = std::make_shared<ov::op::v8::Gather>(value_past, beam_idx, axis);
202+
203+
auto key_data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 32, -1, 240});
204+
auto vs_axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
205+
auto split_lengths = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int64_t>{80, 80, -1});
206+
auto var_split = std::make_shared<ov::op::v1::VariadicSplit>(key_data, vs_axis, split_lengths);
207+
auto parameter_value = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 32, -1, 80});
208+
auto key_cache = std::make_shared<ov::intel_gpu::op::KVCache>(key_gather_past, var_split->output(0), key_variable, 0, ov::element::f16);
209+
auto value_cache = std::make_shared<ov::intel_gpu::op::KVCache>(value_gather_past, parameter_value, value_variable, 0, ov::element::f16);
210+
211+
auto sdpa_q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 32, -1, 80});
212+
auto attn_mask = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 1, -1, -1});
213+
auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{});
214+
auto inputs = ov::OutputVector{sdpa_q, key_cache, value_cache, attn_mask, scale};
215+
auto sdpa = std::make_shared<ov::intel_gpu::op::SDPA>(inputs, is_causal, in0_order, in1_order, in2_order, out_order);
216+
auto result = std::make_shared<ov::op::v0::Result>(sdpa);
217+
218+
model = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{key_data, parameter_value, beam_idx, sdpa_q, attn_mask, scale});
219+
manager.register_pass<IndirectKVCache>();
220+
}
221+
{
222+
auto indirect_axis = 0;
223+
auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1});
224+
auto key_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1, 32, -1, 80}, ov::element::f16, "v0"});
225+
auto value_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1, 32, -1, 80}, ov::element::f16, "v1"});
226+
auto key_past = std::make_shared<ov::intel_gpu::op::ReadValue>(key_variable);
227+
auto value_past = std::make_shared<ov::intel_gpu::op::ReadValue>(value_variable);
228+
229+
auto key_data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 32, -1, 240});
230+
auto vs_axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
231+
auto split_lengths = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int64_t>{80, 80, -1});
232+
auto var_split = std::make_shared<ov::op::v1::VariadicSplit>(key_data, vs_axis, split_lengths);
233+
auto parameter_value = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 32, -1, 80});
234+
auto key_cache = std::make_shared<ov::intel_gpu::op::KVCache>(key_past, var_split->output(0), beam_idx, key_variable, 0, 0, ov::element::f16);
235+
auto value_cache = std::make_shared<ov::intel_gpu::op::KVCache>(value_past, parameter_value, beam_idx, key_variable, 0, 0, ov::element::f16);
236+
237+
auto sdpa_q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 32, -1, 80});
238+
auto attn_mask = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1, 1, -1, -1});
239+
auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{});
240+
auto inputs = ov::OutputVector{sdpa_q, key_cache, value_cache, attn_mask, scale};
241+
242+
auto sdpa = std::make_shared<ov::intel_gpu::op::IndirectSDPA>(inputs, key_cache->output(1), is_causal, indirect_axis, in0_order, in1_order, in2_order, out_order);
243+
auto result = std::make_shared<ov::op::v0::Result>(sdpa);
244+
245+
model_ref = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{key_data, parameter_value, beam_idx, sdpa_q, attn_mask, scale});
246+
comparator.enable(FunctionsComparator::ATTRIBUTES);
247+
}
248+
}

0 commit comments

Comments
 (0)