Skip to content

Commit 570a15d

Browse files
luo-cheng2021 authored and MirceaDan99 committed
[CPU] Fuse SDPA and Concat as early as possible (openvinotoolkit#28189)
### Details:
- *Move StatefulSDPAFusion before CommonOptimizations*
- *...*

### Tickets:
- *[158738](https://jira.devtools.intel.com/browse/CVS-158738)*
1 parent 8a02f9a commit 570a15d

File tree

7 files changed

+152
-51
lines changed

7 files changed

+152
-51
lines changed

src/common/transformations/include/transformations/utils/gen_pattern.hpp

+82 -15
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,14 @@ namespace gen_pattern {
4040

4141
#ifdef CPU_DEBUG_CAPS
4242

43+
# ifdef __GNUC__
44+
# define CURRENT_LINE_NO __builtin_LINE()
45+
# define CURRENT_FILE __builtin_FILE()
46+
# else
47+
# define CURRENT_LINE_NO -1
48+
# define CURRENT_FILE ""
49+
# endif
50+
4351
template <typename... Args>
4452
static inline void _verbose_log(Args&&... args) {
4553
std::stringstream ss;
@@ -58,6 +66,10 @@ static bool matcher_verbose_enabled() {
5866
if (matcher_verbose_enabled()) \
5967
_verbose_log(__VA_ARGS__)
6068
#else
69+
70+
# define CURRENT_LINE_NO -1
71+
# define CURRENT_FILE ""
72+
6173
static bool matcher_verbose_enabled() {
6274
return false;
6375
}
@@ -181,6 +193,8 @@ class Symbol {
181193
double literal_const_value;
182194
std::shared_ptr<Entity> lhs;
183195
std::shared_ptr<Entity> rhs;
196+
const char* filename = "";
197+
int line_no = -1;
184198
// _,+,-,*,/
185199
// l : literal const
186200
// n : named symbol
@@ -220,10 +234,12 @@ class Symbol {
220234
entity->op = 'n';
221235
entity->name = name;
222236
}
223-
Symbol(const int value) {
237+
Symbol(const int value, int line_no = CURRENT_LINE_NO, const char* file = CURRENT_FILE) {
224238
entity = std::make_shared<Entity>();
225239
entity->op = 'l';
226240
entity->literal_const_value = value;
241+
entity->line_no = line_no;
242+
entity->filename = file;
227243
}
228244
Symbol(char op, const Symbol& lhs, const Symbol& rhs) {
229245
entity = std::make_shared<Entity>();
@@ -246,8 +262,12 @@ class Symbol {
246262
void* get_id() const {
247263
return entity.get();
248264
}
249-
const char* get_name() const {
250-
return entity->name;
265+
std::string get_name() const {
266+
if (entity->line_no == -1 || is_independent_var())
267+
return entity->name;
268+
auto filename = strrchr(entity->filename, '/') ? strrchr(entity->filename, '/') + 1 : entity->filename;
269+
std::string name(filename); // use filename:lineno instead
270+
return name + ":" + std::to_string(entity->line_no);
251271
}
252272
bool operator<(const Symbol& rhs) const {
253273
return get_id() < rhs.get_id();
@@ -739,7 +759,9 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
739759
explicit GenericPattern(const DiscreteTypeInfo& type_info,
740760
const OutputVector& args,
741761
const detail::AttrMap& attrs,
742-
const char* vt)
762+
const char* vt,
763+
const int line_no = -1,
764+
const char* file = "")
743765
: ov::pass::pattern::op::Pattern(args),
744766
m_type_info(type_info),
745767
m_attrs(attrs),
@@ -758,6 +780,12 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
758780
sep = ",";
759781
}
760782
ss << ")";
783+
if (line_no != -1) {
784+
// add the code line no to the log:
785+
// O P752<opset1::Multiply>(P736,P745)@fuse_rotary_positional_embeddings.cpp:551 vs ...
786+
auto filename = strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
787+
ss << "@" << filename << ":" << line_no;
788+
}
761789
m_signature = ss.str();
762790
set_friendly_name(std::string("P") + std::to_string(id));
763791
}
@@ -776,7 +804,13 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
776804
// strictly requires pattern & graph value to come from output port with same index,
777805
// this is absolute necessary when pattern contains split node connections.
778806
if (pattern_value.get_index() != graph_value.get_index()) {
779-
_VERBOSE_LOG(level, "X output index mismatch: ", pattern_value.get_index(), "!=", graph_value.get_index());
807+
_VERBOSE_LOG(level,
808+
"X output index mismatch:(",
809+
m_signature,
810+
"): ",
811+
pattern_value.get_index(),
812+
"!=",
813+
graph_value.get_index());
780814
return false;
781815
}
782816

@@ -1018,15 +1052,18 @@ template <class T>
10181052
std::shared_ptr<Node> makePattern(const std::vector<detail::PatternNode>& inputs,
10191053
detail::AttrMap attrmap = {},
10201054
const char* vt = nullptr,
1021-
const char* friendly_name = nullptr) {
1055+
const char* friendly_name = nullptr,
1056+
int line_no = CURRENT_LINE_NO,
1057+
const char* file = CURRENT_FILE) {
10221058
OutputVector args;
10231059
for (auto& in : inputs)
10241060
args.push_back(in.get_output());
10251061

10261062
// pattern nodes are better for pattern matching because
10271063
// - it can be generic/incomplete, so normal OP node is not working properly
10281064
// - it has predicate to correctly decide which branch to take (in Or pattern)
1029-
auto pattern_node = std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt);
1065+
auto pattern_node =
1066+
std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt, line_no, file);
10301067

10311068
if (friendly_name)
10321069
pattern_node->set_friendly_name(friendly_name);
@@ -1120,7 +1157,9 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
11201157
detail::PatternNode start,
11211158
detail::PatternNode stop,
11221159
detail::PatternNode step,
1123-
size_t axis) {
1160+
size_t axis,
1161+
int line_no = CURRENT_LINE_NO,
1162+
const char* file = CURRENT_FILE) {
11241163
std::vector<int64_t> begin_mask(axis + 1, 1);
11251164
std::vector<int64_t> end_mask(axis + 1, 1);
11261165
std::vector<int64_t> new_axis_mask;
@@ -1135,12 +1174,27 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
11351174
{"end_mask", end_mask},
11361175
{"new_axis_mask", new_axis_mask},
11371176
{"shrink_axis_mask", shrink_axis_mask},
1138-
{"ellipsis_mask", ellipsis_mask}});
1177+
{"ellipsis_mask", ellipsis_mask}},
1178+
nullptr,
1179+
nullptr,
1180+
line_no,
1181+
file);
11391182
return opt2;
11401183
}
11411184

1142-
inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) {
1143-
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}});
1185+
inline std::shared_ptr<Node> GenSlice(detail::PatternNode data,
1186+
Symbol start,
1187+
Symbol stop,
1188+
Symbol step,
1189+
size_t axis,
1190+
int line_no = CURRENT_LINE_NO,
1191+
const char* file = CURRENT_FILE) {
1192+
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}},
1193+
{},
1194+
nullptr,
1195+
nullptr,
1196+
line_no,
1197+
file);
11441198

11451199
std::vector<Symbol> vbegin(axis + 1, Symbol(0));
11461200
std::vector<Symbol> vend(axis + 1, Symbol(0));
@@ -1168,7 +1222,11 @@ inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Sy
11681222
{"end_mask", end_mask},
11691223
{"new_axis_mask", new_axis_mask},
11701224
{"shrink_axis_mask", shrink_axis_mask},
1171-
{"ellipsis_mask", ellipsis_mask}});
1225+
{"ellipsis_mask", ellipsis_mask}},
1226+
nullptr,
1227+
nullptr,
1228+
line_no,
1229+
file);
11721230
return opt1 | opt2;
11731231
}
11741232

@@ -1329,7 +1387,9 @@ class PatternValidator {
13291387
auto id = sym.get_id();
13301388
if (symbol_value_map.count(id)) {
13311389
if (symbol_value_map[id] != value) {
1332-
_VERBOSE_LOG(" in-consistency between multiple references of same symbol : ",
1390+
_VERBOSE_LOG(" in-consistency between multiple references of same symbol(",
1391+
sym.get_name(),
1392+
"): ",
13331393
symbol_value_map[id],
13341394
" != ",
13351395
value);
@@ -1345,7 +1405,12 @@ class PatternValidator {
13451405
if (sym.is_literal_const()) {
13461406
auto literal = sym.eval(symbol_value_map);
13471407
if (literal != value) {
1348-
_VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value);
1408+
_VERBOSE_LOG(" mismatch between literal symbol & value(",
1409+
sym.get_name(),
1410+
"): ",
1411+
literal,
1412+
" != ",
1413+
value);
13491414
return false;
13501415
}
13511416
// no need to put literal into value map to eval them.
@@ -1373,7 +1438,9 @@ class PatternValidator {
13731438
}
13741439
}
13751440
if (!is_match) {
1376-
_VERBOSE_LOG(" mismatch between derived & value : ",
1441+
_VERBOSE_LOG(" mismatch between derived & value(",
1442+
sym.get_name(),
1443+
"): ",
13771444
std::setprecision(std::numeric_limits<float>::max_digits10),
13781445
derived,
13791446
" != ",

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp

+42 -11
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
#include <openvino/opsets/opset13.hpp>
1313
#include <openvino/opsets/opset6.hpp>
1414
#include <openvino/opsets/opset8.hpp>
15+
#include <openvino/pass/manager.hpp>
1516
#include <openvino/pass/pattern/op/or.hpp>
1617
#include <openvino/pass/pattern/op/wrap_type.hpp>
1718
#include <transformations/utils/gen_pattern.hpp>
@@ -20,7 +21,12 @@
2021
#include "itt.hpp"
2122
#include "openvino/opsets/opset1.hpp"
2223
#include "ov_ops/type_relaxed.hpp"
24+
#include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp"
2325
#include "transformations/cpu_opset/common/op/sdpa.hpp"
26+
#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp"
27+
#include "transformations/defs.hpp"
28+
#include "transformations/op_conversions/convert_broadcast3.hpp"
29+
#include "transformations/transpose_sinking/ts_shape_of.hpp"
2430
using namespace ov::gen_pattern;
2531

2632
namespace ov {
@@ -56,8 +62,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
5662
std::shared_ptr<Node> reshape_k, reshape_v, unsqueeze_k, unsqueeze_v;
5763
std::shared_ptr<Node> computed_bcst_k, computed_bcst_v, multiply_k, multiply_v;
5864
std::shared_ptr<Node> mq_reshape_k, mq_reshape_v;
65+
std::shared_ptr<Node> computed_bcst3_k, computed_bcst3_v;
5966
auto multi_query_bcst = [](const std::shared_ptr<Node>& kv) {
60-
auto reshape_kv = wrap_type<opset6::Reshape>({kv, any_input()});
67+
auto reshape_kv = makePattern<opset6::Reshape>({kv, any_input()});
6168
auto unsqueeze_kv = makePattern<opset1::Unsqueeze>({kv, any_input()});
6269

6370
auto check_one = [](Output<Node> output) -> bool {
@@ -73,13 +80,17 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
7380
makePattern<opset1::Broadcast>({wrap_type<opset1::Constant>(check_one), any_input(), any_input()},
7481
{{"mode", "numpy"}});
7582

76-
auto multiply_kv = wrap_type<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
77-
auto result = wrap_type<opset6::Reshape>({multiply_kv, any_input()});
78-
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv);
83+
auto multiply_kv = makePattern<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
84+
auto computed_bcst3 = makePattern<opset3::Broadcast>({unsqueeze_kv, any_input()}, {{"mode", "bidirectional"}});
85+
86+
auto result = makePattern<opset6::Reshape>({multiply_kv | computed_bcst3, any_input()});
87+
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv, computed_bcst3);
7988
};
8089

81-
std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k) = multi_query_bcst(concat_k);
82-
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v) = multi_query_bcst(concat_v);
90+
std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) =
91+
multi_query_bcst(concat_k);
92+
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v, computed_bcst3_v) =
93+
multi_query_bcst(concat_v);
8394
auto present_k = concat_k | mq_reshape_k;
8495
auto present_v = concat_v | mq_reshape_v;
8596

@@ -178,15 +189,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
178189

179190
opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
180191
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
181-
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))
192+
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) {
182193
return false;
183-
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
194+
}
195+
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) {
184196
return false;
197+
}
185198

186-
if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node))
199+
if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node)) {
187200
return false;
188-
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
201+
}
202+
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) {
189203
return false;
204+
}
190205

191206
auto is_optional_one_child = [&pattern_map](const std::vector<std::shared_ptr<Node>>& nodes) {
192207
for (auto&& node : nodes) {
@@ -212,7 +227,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
212227
computed_bcst_v,
213228
multiply_v,
214229
mq_reshape_k,
215-
mq_reshape_v})) {
230+
mq_reshape_v,
231+
computed_bcst3_k,
232+
computed_bcst3_v})) {
216233
return false;
217234
}
218235

@@ -284,5 +301,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
284301
this->register_matcher(m, callback);
285302
}
286303

304+
bool SDPASubgraphFusion::run_on_model(const std::shared_ptr<ov::Model>& f) {
305+
RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion);
306+
ov::pass::Manager manager("SDPASubgraphFusion");
307+
308+
CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf);
309+
CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
310+
CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
311+
// TODO: remove the following after snippets support patterns with dynamic shapes
312+
CPU_REGISTER_PASS_X64(manager, ov::intel_cpu::SDPAFuseTransposeReshape);
313+
314+
manager.run_passes(f);
315+
return false;
316+
}
317+
287318
} // namespace intel_cpu
288319
} // namespace ov

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp

+7
Original file line number | Diff line number | Diff line change
@@ -14,5 +14,12 @@ class StatefulSDPAFusion : public ov::pass::MatcherPass {
1414
StatefulSDPAFusion();
1515
};
1616

17+
class SDPASubgraphFusion : public ov::pass::ModelPass {
18+
public:
19+
OPENVINO_RTTI("SDPASubgraphFusion", "0");
20+
21+
bool run_on_model(const std::shared_ptr<ov::Model>& f) override;
22+
};
23+
1724
} // namespace intel_cpu
1825
} // namespace ov

src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.cpp

+6 -6
Original file line number | Diff line number | Diff line change
@@ -18,13 +18,13 @@
1818
* Description: SDPA fuse transpose and reshape.
1919
* Original pattern Fused pattern
2020
*
21-
* input1 input2 input3
21+
* input1 readvalue readvalue
2222
* | | |
2323
* q_reshape k_reshape v_reshap
2424
* | | | (qkv transpose and reshape's orders)
25-
* q_transpose k_transpose v_transpose |
26-
* \ | / input1 input2 input3 |
27-
* \ | / \ | / /
25+
* q_transpose k_transpose v_transpose |
26+
* \ | / input1 ReadValue ReadValue |
27+
* \ | / \ | / /
2828
* ScaledDotProductAttention ---------> SDPAWithTransposeReshape
2929
* | |
3030
* out_transpose |
@@ -41,8 +41,8 @@ intel_cpu::SDPAFuseTransposeReshape::SDPAFuseTransposeReshape() {
4141
MATCHER_SCOPE(SDPAFuseTransposeReshape);
4242

4343
auto q_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
44-
auto k_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
45-
auto v_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
44+
auto k_reshape_node = wrap_type<op::v1::Reshape>({wrap_type<op::v6::ReadValue>(), any_input()});
45+
auto v_reshape_node = wrap_type<op::v1::Reshape>({wrap_type<op::v6::ReadValue>(), any_input()});
4646

4747
auto q_transpose_order_node = wrap_type<op::v0::Constant>();
4848
auto k_transpose_order_node = wrap_type<op::v0::Constant>();

0 commit comments

Comments
 (0)