15
15
#include " openvino/op/parameter.hpp"
16
16
#include " openvino/op/result.hpp"
17
17
#include " openvino/op/gather.hpp"
18
+ #include " openvino/op/variadic_split.hpp"
18
19
#include " openvino/pass/manager.hpp"
19
20
20
21
#include < transformations/utils/utils.hpp>
@@ -182,3 +183,66 @@ TEST_F(TransformationTestsF, IndirectKVCache4) {
182
183
comparator.enable (FunctionsComparator::ATTRIBUTES);
183
184
}
184
185
}
186
+
187
+ TEST_F (TransformationTestsF, IndirectKVCache5) {
188
+ std::vector<int64_t > in0_order = {0 , 1 , 2 , 3 };
189
+ std::vector<int64_t > in1_order = {0 , 1 , 2 , 3 };
190
+ std::vector<int64_t > in2_order = {0 , 1 , 2 , 3 };
191
+ std::vector<int64_t > out_order = {0 , 1 , 2 , 3 };
192
+ const bool is_causal = false ;
193
+ {
194
+ auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1 });
195
+ auto key_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1 , 32 , -1 , 80 }, ov::element::f16, " v0" });
196
+ auto value_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1 , 32 , -1 , 80 }, ov::element::f16, " v1" });
197
+ auto key_past = std::make_shared<ov::intel_gpu::op::ReadValue>(key_variable);
198
+ auto value_past = std::make_shared<ov::intel_gpu::op::ReadValue>(value_variable);
199
+ auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 0 );
200
+ auto key_gather_past = std::make_shared<ov::op::v8::Gather>(key_past, beam_idx, axis);
201
+ auto value_gather_past = std::make_shared<ov::op::v8::Gather>(value_past, beam_idx, axis);
202
+
203
+ auto key_data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 32 , -1 , 240 });
204
+ auto vs_axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1 );
205
+ auto split_lengths = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3 }, std::vector<int64_t >{80 , 80 , -1 });
206
+ auto var_split = std::make_shared<ov::op::v1::VariadicSplit>(key_data, vs_axis, split_lengths);
207
+ auto parameter_value = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 32 , -1 , 80 });
208
+ auto key_cache = std::make_shared<ov::intel_gpu::op::KVCache>(key_gather_past, var_split->output (0 ), key_variable, 0 , ov::element::f16);
209
+ auto value_cache = std::make_shared<ov::intel_gpu::op::KVCache>(value_gather_past, parameter_value, value_variable, 0 , ov::element::f16);
210
+
211
+ auto sdpa_q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 32 , -1 , 80 });
212
+ auto attn_mask = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 1 , -1 , -1 });
213
+ auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{});
214
+ auto inputs = ov::OutputVector{sdpa_q, key_cache, value_cache, attn_mask, scale};
215
+ auto sdpa = std::make_shared<ov::intel_gpu::op::SDPA>(inputs, is_causal, in0_order, in1_order, in2_order, out_order);
216
+ auto result = std::make_shared<ov::op::v0::Result>(sdpa);
217
+
218
+ model = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{key_data, parameter_value, beam_idx, sdpa_q, attn_mask, scale});
219
+ manager.register_pass <IndirectKVCache>();
220
+ }
221
+ {
222
+ auto indirect_axis = 0 ;
223
+ auto beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1 });
224
+ auto key_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1 , 32 , -1 , 80 }, ov::element::f16, " v0" });
225
+ auto value_variable = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{{-1 , 32 , -1 , 80 }, ov::element::f16, " v1" });
226
+ auto key_past = std::make_shared<ov::intel_gpu::op::ReadValue>(key_variable);
227
+ auto value_past = std::make_shared<ov::intel_gpu::op::ReadValue>(value_variable);
228
+
229
+ auto key_data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 32 , -1 , 240 });
230
+ auto vs_axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1 );
231
+ auto split_lengths = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3 }, std::vector<int64_t >{80 , 80 , -1 });
232
+ auto var_split = std::make_shared<ov::op::v1::VariadicSplit>(key_data, vs_axis, split_lengths);
233
+ auto parameter_value = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 32 , -1 , 80 });
234
+ auto key_cache = std::make_shared<ov::intel_gpu::op::KVCache>(key_past, var_split->output (0 ), beam_idx, key_variable, 0 , 0 , ov::element::f16);
235
+ auto value_cache = std::make_shared<ov::intel_gpu::op::KVCache>(value_past, parameter_value, beam_idx, key_variable, 0 , 0 , ov::element::f16);
236
+
237
+ auto sdpa_q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 32 , -1 , 80 });
238
+ auto attn_mask = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{-1 , 1 , -1 , -1 });
239
+ auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{});
240
+ auto inputs = ov::OutputVector{sdpa_q, key_cache, value_cache, attn_mask, scale};
241
+
242
+ auto sdpa = std::make_shared<ov::intel_gpu::op::IndirectSDPA>(inputs, key_cache->output (1 ), is_causal, indirect_axis, in0_order, in1_order, in2_order, out_order);
243
+ auto result = std::make_shared<ov::op::v0::Result>(sdpa);
244
+
245
+ model_ref = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{key_data, parameter_value, beam_idx, sdpa_q, attn_mask, scale});
246
+ comparator.enable (FunctionsComparator::ATTRIBUTES);
247
+ }
248
+ }
0 commit comments