
Commit df6a258
[GPU] Added RoPE support for ChatGLM and Qwen (openvinotoolkit#24756)

### Details:
- Added RoPE fusion support for ChatGLM and Qwen models on GPU
- Moved the RoPE functional tests into shared test classes and refactored them

### Tickets:
- [CVS-119150](https://jira.devtools.intel.com/browse/CVS-119150)

1 parent ba8d6c5 commit df6a258
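For reference, RoPE rotates each pair of channels in the query/key heads by a position-dependent angle; the standard RoFormer formulation (background, not part of this diff) is:

```math
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad
\theta_i = 10000^{-2i/d}
```

where `m` is the token position and `d` is the number of rotated channels (`rotary_ndims` below). ChatGLM and Qwen differ from the plain Llama layout mainly in how Q/K/V are packed into one tensor and how the cos/sin cache is laid out, which is why the fusion pass and the GPU kernel carry per-model flags (`is_chatglm`, `is_qwen`).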

File tree

28 files changed: +1475 -631 lines

src/common/transformations/include/transformations/utils/gen_pattern.hpp (+11 -1)
```diff
@@ -1212,7 +1212,8 @@ class PatternValidator {
             return false;
         }
 
-        if (ele_type == ov::element::i32 || ele_type == ov::element::f32 || ele_type == ov::element::i64) {
+        if (ele_type == ov::element::i32 || ele_type == ov::element::i64 || ele_type == ov::element::f16 ||
+            ele_type == ov::element::f32) {
             auto observed = constop->cast_vector<double>();
             for (size_t i = 0; i < symbols.size(); i++)
                 detail::add_symbol_observed(sov, symbols[i], observed[i]);
@@ -1259,6 +1260,15 @@ class PatternValidator {
             }
         }
 
+        if (pconst_node->get_output_element_type(0).is_real() &&
+            vconst_node->get_output_element_type(0).is_real()) {
+            auto p_values = pconst_node->cast_vector<float>();
+            auto v_values = vconst_node->cast_vector<float>();
+            if (p_values == v_values) {
+                continue;
+            }
+        }
+
         _VERBOSE_LOG("expecting Constant of type ",
                      pconst_node->get_output_element_type(0),
                      " but got ",
```

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp (+8 -7)
```diff
@@ -9,6 +9,7 @@
 
 #include "itt.hpp"
 #include "openvino/core/rt_info.hpp"
+#include "openvino/op/util/shape_of_base.hpp"
 #include "openvino/opsets/opset1.hpp"
 #include "openvino/opsets/opset6.hpp"
 #include "openvino/opsets/opset8.hpp"
@@ -415,9 +416,9 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {
     MATCHER_SCOPE(RoPEFusionChatGLM);
 
-    auto qkv_linear = makePattern("f32[?,?,?]");  // f32[seq_length, batch_size, 4608]
+    auto qkv_linear = makePattern("[?,?,?]");  // [seq_length, batch_size, 4608]
     auto seq_length = makePattern("i32[1]");
-    auto cos_sin_cache = makePattern("f32[?,?,?,?]");  // [max_pos_embeddings, batch_size, 32, 2]
+    auto cos_sin_cache = makePattern("[?,?,?,?]");  // [max_pos_embeddings, batch_size, 32, 2]
 
     auto ndims = ov::gen_pattern::Symbol("ndims");
     auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
@@ -538,9 +539,9 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     MATCHER_SCOPE(RoPEFusionQwen);
 
     // rotary_emb_cos & rotary_emb_sin are sliced by present kv-length (past-kv-length + cur_len)
-    auto rotary_emb_cos = makePattern("f32[1,?,1,?]");  // [1,..4096,1,128]
-    auto rotary_emb_sin = makePattern("f32[1,?,1,?]");  // [1,..4096,1,128]
-    auto qkv_proj = makePattern("f32[?,?,?]");  // f32[?,?,12288]
+    auto rotary_emb_cos = makePattern("[1,?,1,?]");  // [1,..4096,1,128]
+    auto rotary_emb_sin = makePattern("[1,?,1,?]");  // [1,..4096,1,128]
+    auto qkv_proj = makePattern("[?,?,?]");  // [?,?,12288]
 
     auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
     auto head_size = ov::gen_pattern::Symbol("head_size");
@@ -559,8 +560,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     auto Multiply_567524 = makePattern<opset1::Multiply>({ShapeOf_485735, {-1}}, {{"auto_broadcast", "numpy"}});
     auto Gather_377635 = makePattern<opset8::Gather>({Multiply_567524, {1}, 0}, {{"batch_dims", 0}});
 
-    auto input_ids = makePattern("i32[?,?]");  // [batch, length]
-    auto ShapeOf_409241 = makePattern<opset1::ShapeOf>({input_ids}, {});
+    auto input_ids = makePattern();  // [batch, length]
+    auto ShapeOf_409241 = makePattern<ov::op::util::ShapeOfBase>({input_ids}, {});
     auto Gather_311651 = makePattern<opset8::Gather>({ShapeOf_409241, {1}, 0}, {{"batch_dims", 0}});
     auto neg_Multiply = makePattern<opset1::Multiply>({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}});
```
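The pattern changes above follow one theme: dropping the explicit `f32` element type from the shape strings so the fusion also fires on f16 graphs (the GPU plugin's usual inference precision), and widening the `ShapeOf` match to `ov::op::util::ShapeOfBase` so any `ShapeOf` version is accepted. A sketch of the shape-string convention, as we read it from this diff rather than from documentation:

```cpp
// Sketch of gen_pattern shape strings as used in this pass: a leading
// element type constrains the match; omitting it leaves precision free.
#include "transformations/utils/gen_pattern.hpp"

using namespace ov::gen_pattern;

void sketch() {
    auto f32_only = makePattern("f32[?,?,?]");  // 3-D tensor, f32 only
    auto any_prec = makePattern("[?,?,?]");     // 3-D tensor, any element type (f32, f16, ...)
    auto untyped = makePattern();               // no shape or type constraint at all
}
```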

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp (-621)

This file was deleted.

src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp (+2)

```diff
@@ -371,6 +371,8 @@ std::vector<std::string> disabledTestPatterns() {
     retVector.emplace_back(R"(smoke_VariableState/OVInferRequestVariableStateTest.*)");
     // Issue: 141705
     retVector.emplace_back(R"(.*smoke_arm_Deconv_2D_Planar_FP16/DeconvolutionLayerCPUTest.*INFERENCE_PRECISION_HINT=f16.*)");
+
+    retVector.emplace_back(R"(.*smoke_RoPETest.*)");
 #endif
 
 #if defined(OPENVINO_ARCH_ARM)
```

New file (+32 lines): CPU instantiations of the shared RoPE test suites

```cpp
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/rotary_pos_emb.hpp"

namespace ov {
namespace test {

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2,
                         RoPETestLlama2,
                         ::testing::Values(ov::test::utils::DEVICE_CPU),
                         RoPETestLlama2::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
                         RoPETestChatGLM,
                         ::testing::Values(ov::test::utils::DEVICE_CPU),
                         RoPETestChatGLM::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b,
                         RoPETestQwen7b,
                         ::testing::Combine(::testing::Values(true, false),
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         RoPETestQwen7b::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJ,
                         RoPETestGPTJ,
                         ::testing::Combine(::testing::Values(true, false),
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         RoPETestGPTJ::getTestCaseName);
}  // namespace test
}  // namespace ov
```
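The CPU file above instantiates suites that now live in shared test classes; the GPU counterpart added by this commit is among the 28 changed files but not shown in this excerpt. A hypothetical sketch of what such an instantiation would look like, assuming the suite names carry over unchanged:

```cpp
// Hypothetical GPU-side instantiation mirroring the CPU file; not copied
// from the commit.
#include "subgraph_tests/rotary_pos_emb.hpp"

namespace ov {
namespace test {

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
                         RoPETestChatGLM,
                         ::testing::Values(ov::test::utils::DEVICE_GPU),
                         RoPETestChatGLM::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b,
                         RoPETestQwen7b,
                         ::testing::Combine(::testing::Values(true, false),
                                            ::testing::Values(ov::test::utils::DEVICE_GPU)),
                         RoPETestQwen7b::getTestCaseName);

}  // namespace test
}  // namespace ov
```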

src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp (+1)

```diff
@@ -286,3 +286,4 @@ REGISTER_FACTORY(internal, Convolution);
 REGISTER_FACTORY(internal, Placeholder);
 REGISTER_FACTORY(internal, SDPA);
 REGISTER_FACTORY(internal, IndirectSDPA);
+REGISTER_FACTORY(internal, RoPE);
```

New file (+92 lines): intel_gpu/primitives/rope.hpp

```cpp
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"

namespace cldnn {
using RoPE = ov::op::internal::RoPE;

/// @brief Rotary Position Embedding primitive
struct rope : public primitive_base<rope> {
    CLDNN_DECLARE_PRIMITIVE(rope);

    rope() : primitive_base("", {}) {}

    /// @brief Constructs rope primitive
    /// @param id This primitive id
    /// @param inputs Inputs primitive ids
    /// @param config Specific RoPE config
    rope(const primitive_id& id,
         const std::vector<input_info>& inputs,
         const RoPE::Config& config,
         const padding& output_padding = padding())
        : primitive_base(id, inputs, {output_padding}),
          config(config) {}

    RoPE::Config config;

    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, config.gather_position_arg_id);
        seed = hash_combine(seed, config.head_cnt);
        seed = hash_combine(seed, config.head_size);
        seed = hash_combine(seed, config.input_trans0213);
        seed = hash_combine(seed, config.is_chatglm);
        seed = hash_combine(seed, config.is_interleaved);
        seed = hash_combine(seed, config.is_qwen);
        seed = hash_combine(seed, config.rotary_ndims);
        seed = hash_combine(seed, config.slice_start);
        seed = hash_combine(seed, config.slice_stop);
        return seed;
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;

        auto rhs_casted = downcast<const rope>(rhs);

        return config.gather_position_arg_id == rhs_casted.config.gather_position_arg_id &&
               config.head_cnt == rhs_casted.config.head_cnt &&
               config.head_size == rhs_casted.config.head_size &&
               config.input_trans0213 == rhs_casted.config.input_trans0213 &&
               config.is_chatglm == rhs_casted.config.is_chatglm &&
               config.is_interleaved == rhs_casted.config.is_interleaved &&
               config.is_qwen == rhs_casted.config.is_qwen &&
               config.rotary_ndims == rhs_casted.config.rotary_ndims &&
               config.slice_start == rhs_casted.config.slice_start &&
               config.slice_stop == rhs_casted.config.slice_stop;
    }

    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<rope>::save(ob);
        ob << config.gather_position_arg_id;
        ob << config.head_cnt;
        ob << config.head_size;
        ob << config.input_trans0213;
        ob << config.is_chatglm;
        ob << config.is_interleaved;
        ob << config.is_qwen;
        ob << config.rotary_ndims;
        ob << config.slice_start;
        ob << config.slice_stop;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<rope>::load(ib);
        ib >> config.gather_position_arg_id;
        ib >> config.head_cnt;
        ib >> config.head_size;
        ib >> config.input_trans0213;
        ib >> config.is_chatglm;
        ib >> config.is_interleaved;
        ib >> config.is_qwen;
        ib >> config.rotary_ndims;
        ib >> config.slice_start;
        ib >> config.slice_stop;
    }
};
}  // namespace cldnn
```
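For a sense of how this plugs into a cldnn topology, here is a hypothetical construction of the primitive; the ids and config values are illustrative, not taken from the commit:

```cpp
// Hypothetical usage sketch; "rope", the input ids and the numeric values
// are made up for illustration.
#include "intel_gpu/primitives/rope.hpp"

cldnn::rope make_rope_primitive() {
    ov::op::internal::RoPE::Config config;
    config.head_cnt = 32;       // number of attention heads
    config.head_size = 128;     // channels per head
    config.rotary_ndims = 128;  // how many leading channels get rotated
    // is_chatglm / is_qwen / is_interleaved left unset: the plain
    // three-input (data, cos, sin) variant.

    return cldnn::rope("rope",
                       {cldnn::input_info("qkv_projection"),
                        cldnn::input_info("cos_cache"),
                        cldnn::input_info("sin_cache")},
                       config);
}
```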

src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp (+1)

```diff
@@ -94,6 +94,7 @@ void register_implementations() {
     REGISTER_OCL(unique_count);
     REGISTER_OCL(unique_gather);
     REGISTER_OCL(scaled_dot_product_attention);
+    REGISTER_OCL(rope);
 }
 
 }  // namespace ocl
```

src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp (+2)

```diff
@@ -75,6 +75,7 @@
 #include "intel_gpu/primitives/unique.hpp"
 #include "intel_gpu/primitives/kv_cache.hpp"
 #include "intel_gpu/primitives/scaled_dot_product_attention.hpp"
+#include "intel_gpu/primitives/rope.hpp"
 
 namespace cldnn {
 namespace ocl {
@@ -174,6 +175,7 @@ REGISTER_OCL(eye);
 REGISTER_OCL(unique_count);
 REGISTER_OCL(unique_gather);
 REGISTER_OCL(scaled_dot_product_attention);
+REGISTER_OCL(rope);
 
 #undef REGISTER_OCL
 
```

New file (+88 lines): OCL implementation of the rope primitive

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"

#include "rope_inst.h"
#include "rope/rope_kernel_selector.h"
#include "rope/rope_kernel_ref.h"

namespace cldnn {
namespace ocl {

struct rope_impl : typed_primitive_impl_ocl<rope> {
    using parent = typed_primitive_impl_ocl<rope>;
    using parent::parent;
    using kernel_selector_t = kernel_selector::rope_kernel_selector;
    using kernel_params_t = kernel_selector::rope_params;

    DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::rope_impl);

    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<rope_impl>(*this);
    }

    void load(BinaryInputBuffer& ib) override {
        parent::load(ib);
        if (is_dynamic()) {
            auto& kernel_selector = kernel_selector_t::Instance();
            auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
            kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
        }
    }

    static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
        const auto& primitive = impl_param.typed_desc<rope>();
        auto params = get_default_params<kernel_selector::rope_params>(impl_param, is_shape_agnostic);

        params.head_cnt = primitive->config.head_cnt;
        params.head_size = primitive->config.head_size;
        params.rotary_ndims = primitive->config.rotary_ndims;

        params.slice_start = primitive->config.slice_start;
        params.slice_stop = primitive->config.slice_stop;

        params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3;
        params.num_of_inputs = primitive->config.is_chatglm || primitive->config.is_interleaved ? 2 : 3;

        params.is_qwen = primitive->config.is_qwen;
        params.is_chatglm = primitive->config.is_chatglm;

        for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
            params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i)));
        }
        return params;
    }

    void update_dispatch_data(const kernel_impl_params& impl_param) override {
        auto kernel_params = get_kernel_params(impl_param, true);
        (_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
    }
};

namespace detail {

attach_rope_impl::attach_rope_impl() {
    auto types = {
        data_types::f32,
        data_types::f16
    };

    auto formats = {
        format::bfyx
    };

    implementation_map<rope>::add(impl_types::ocl,
                                  shape_types::any,
                                  typed_primitive_impl_ocl<rope>::create<rope_impl>,
                                  types,
                                  formats);
}

}  // namespace detail
}  // namespace ocl
}  // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::rope_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::rope)
```
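The two ternaries in `get_kernel_params` encode the per-model layout differences; spelled out (derived only from this diff, with "llama-style" as our label for the default case, not a config flag):

```cpp
// Enumeration of the axis / input-count selection in get_kernel_params:
//
//   ChatGLM                       -> axis 2, 2 inputs (fused cos/sin cache)
//   Qwen                          -> axis 2, 3 inputs
//   interleaved (GPT-J-style)     -> axis 3, 2 inputs
//   default llama-style           -> axis 3, 3 inputs (data, cos, sin)
struct rope_dispatch { size_t axis; size_t num_of_inputs; };

rope_dispatch select_dispatch(bool is_qwen, bool is_chatglm, bool is_interleaved) {
    return {(is_qwen || is_chatglm) ? size_t(2) : size_t(3),
            (is_chatglm || is_interleaved) ? size_t(2) : size_t(3)};
}
```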

New file (+39 lines): rope_inst.h

```cpp
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/primitives/rope.hpp"
#include "primitive_inst.h"

#include <string>

namespace cldnn {
template <>
struct typed_program_node<rope> : public typed_program_node_base<rope> {
    using parent = typed_program_node_base<rope>;

public:
    using parent::parent;

    program_node& input(size_t idx = 0) const { return get_dependency(idx); }
    std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
};

using rope_node = typed_program_node<rope>;

template <>
class typed_primitive_inst<rope> : public typed_primitive_inst_base<rope> {
    using parent = typed_primitive_inst_base<rope>;
    using parent::parent;

public:
    template<typename ShapeType>
    static std::vector<layout> calc_output_layouts(const rope_node& /*node*/, const kernel_impl_params& impl_param);
    static layout calc_output_layout(rope_node const& node, kernel_impl_params const& impl_param);
    static std::string to_string(rope_node const& node);
};

using rope_inst = typed_primitive_inst<rope>;
}  // namespace cldnn
```
