9
9
10
10
#include " itt.hpp"
11
11
#include " openvino/core/rt_info.hpp"
12
+ #include " openvino/op/util/shape_of_base.hpp"
12
13
#include " openvino/opsets/opset1.hpp"
13
14
#include " openvino/opsets/opset6.hpp"
14
15
#include " openvino/opsets/opset8.hpp"
@@ -415,9 +416,9 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
415
416
ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM (int split_output_id) {
416
417
MATCHER_SCOPE (RoPEFusionChatGLM);
417
418
418
- auto qkv_linear = makePattern (" f32 [?,?,?]" ); // f32 [seq_length, batch_size, 4608]
419
+ auto qkv_linear = makePattern (" [?,?,?]" ); // [seq_length, batch_size, 4608]
419
420
auto seq_length = makePattern (" i32[1]" );
420
- auto cos_sin_cache = makePattern (" f32 [?,?,?,?]" ); // [max_pos_embeddings, batch_size, 32, 2]
421
+ auto cos_sin_cache = makePattern (" [?,?,?,?]" ); // [max_pos_embeddings, batch_size, 32, 2]
421
422
422
423
auto ndims = ov::gen_pattern::Symbol (" ndims" );
423
424
auto head_cnt = ov::gen_pattern::Symbol (" head_cnt" );
@@ -538,9 +539,9 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
538
539
MATCHER_SCOPE (RoPEFusionQwen);
539
540
540
541
// rotary_emb_cos & rotary_emb_sin are sliced by present kv-length (past-kv-length + cur_len)
541
- auto rotary_emb_cos = makePattern (" f32 [1,?,1,?]" ); // [1,..4096,1,128]
542
- auto rotary_emb_sin = makePattern (" f32 [1,?,1,?]" ); // [1,..4096,1,128]
543
- auto qkv_proj = makePattern (" f32 [?,?,?]" ); // f32 [?,?,12288]
542
+ auto rotary_emb_cos = makePattern (" [1,?,1,?]" ); // [1,..4096,1,128]
543
+ auto rotary_emb_sin = makePattern (" [1,?,1,?]" ); // [1,..4096,1,128]
544
+ auto qkv_proj = makePattern (" [?,?,?]" ); // [?,?,12288]
544
545
545
546
auto head_cnt = ov::gen_pattern::Symbol (" head_cnt" );
546
547
auto head_size = ov::gen_pattern::Symbol (" head_size" );
@@ -559,8 +560,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
559
560
auto Multiply_567524 = makePattern<opset1::Multiply>({ShapeOf_485735, {-1 }}, {{" auto_broadcast" , " numpy" }});
560
561
auto Gather_377635 = makePattern<opset8::Gather>({Multiply_567524, {1 }, 0 }, {{" batch_dims" , 0 }});
561
562
562
- auto input_ids = makePattern (" i32[?,?] " ); // [batch, length]
563
- auto ShapeOf_409241 = makePattern<opset1::ShapeOf >({input_ids}, {});
563
+ auto input_ids = makePattern (); // [batch, length]
564
+ auto ShapeOf_409241 = makePattern<ov::op::util::ShapeOfBase >({input_ids}, {});
564
565
auto Gather_311651 = makePattern<opset8::Gather>({ShapeOf_409241, {1 }, 0 }, {{" batch_dims" , 0 }});
565
566
auto neg_Multiply = makePattern<opset1::Multiply>({Gather_311651, {-1 }}, {{" auto_broadcast" , " numpy" }});
566
567
0 commit comments