Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Fuse SDPA and Concat as early as possible #28189

Merged
Original file line number Diff line number Diff line change
@@ -40,6 +40,14 @@ namespace gen_pattern {

#ifdef CPU_DEBUG_CAPS

# ifdef __GNUC__
# define CURRENT_LINE_NO __builtin_LINE()
# define CURRENT_FILE __builtin_FILE()
# else
# define CURRENT_LINE_NO -1
# define CURRENT_FILE ""
# endif

template <typename... Args>
static inline void _verbose_log(Args&&... args) {
std::stringstream ss;
@@ -58,6 +66,10 @@ static bool matcher_verbose_enabled() {
if (matcher_verbose_enabled()) \
_verbose_log(__VA_ARGS__)
#else

# define CURRENT_LINE_NO -1
# define CURRENT_FILE ""

static bool matcher_verbose_enabled() {
return false;
}
@@ -181,6 +193,8 @@ class Symbol {
double literal_const_value;
std::shared_ptr<Entity> lhs;
std::shared_ptr<Entity> rhs;
const char* filename = "";
int line_no = -1;
// _,+,-,*,/
// l : literal const
// n : named symbol
@@ -220,10 +234,12 @@ class Symbol {
entity->op = 'n';
entity->name = name;
}
Symbol(const int value) {
Symbol(const int value, int line_no = CURRENT_LINE_NO, const char* file = CURRENT_FILE) {
entity = std::make_shared<Entity>();
entity->op = 'l';
entity->literal_const_value = value;
entity->line_no = line_no;
entity->filename = file;
}
Symbol(char op, const Symbol& lhs, const Symbol& rhs) {
entity = std::make_shared<Entity>();
@@ -246,8 +262,12 @@ class Symbol {
void* get_id() const {
return entity.get();
}
const char* get_name() const {
return entity->name;
std::string get_name() const {
if (entity->line_no == -1 || is_independent_var())
return entity->name;
auto filename = strrchr(entity->filename, '/') ? strrchr(entity->filename, '/') + 1 : entity->filename;
std::string name(filename); // use filename:lineno instead
return name + ":" + std::to_string(entity->line_no);
}
bool operator<(const Symbol& rhs) const {
return get_id() < rhs.get_id();
@@ -739,7 +759,9 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
explicit GenericPattern(const DiscreteTypeInfo& type_info,
const OutputVector& args,
const detail::AttrMap& attrs,
const char* vt)
const char* vt,
const int line_no = -1,
const char* file = "")
: ov::pass::pattern::op::Pattern(args),
m_type_info(type_info),
m_attrs(attrs),
@@ -758,6 +780,12 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
sep = ",";
}
ss << ")";
if (line_no != -1) {
// add the code line no to the log:
// O P752<opset1::Multiply>(P736,P745)@fuse_rotary_positional_embeddings.cpp:551 vs ...
auto filename = strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
ss << "@" << filename << ":" << line_no;
}
m_signature = ss.str();
set_friendly_name(std::string("P") + std::to_string(id));
}
@@ -776,7 +804,13 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
// strictly requires pattern & graph value to come from output port with same index,
// this is absolute necessary when pattern contains split node connections.
if (pattern_value.get_index() != graph_value.get_index()) {
_VERBOSE_LOG(level, "X output index mismatch: ", pattern_value.get_index(), "!=", graph_value.get_index());
_VERBOSE_LOG(level,
"X output index mismatch:(",
m_signature,
"): ",
pattern_value.get_index(),
"!=",
graph_value.get_index());
return false;
}

@@ -1018,15 +1052,18 @@ template <class T>
std::shared_ptr<Node> makePattern(const std::vector<detail::PatternNode>& inputs,
detail::AttrMap attrmap = {},
const char* vt = nullptr,
const char* friendly_name = nullptr) {
const char* friendly_name = nullptr,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
OutputVector args;
for (auto& in : inputs)
args.push_back(in.get_output());

// pattern nodes are better for pattern matching because
// - it can be generic/incomplete, so normal OP node is not working properly
// - it has predicate to correctly decide which branch to take (in Or pattern)
auto pattern_node = std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt);
auto pattern_node =
std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt, line_no, file);

if (friendly_name)
pattern_node->set_friendly_name(friendly_name);
@@ -1120,7 +1157,9 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
detail::PatternNode start,
detail::PatternNode stop,
detail::PatternNode step,
size_t axis) {
size_t axis,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
std::vector<int64_t> begin_mask(axis + 1, 1);
std::vector<int64_t> end_mask(axis + 1, 1);
std::vector<int64_t> new_axis_mask;
@@ -1135,12 +1174,27 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
{"end_mask", end_mask},
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
{"ellipsis_mask", ellipsis_mask}},
nullptr,
nullptr,
line_no,
file);
return opt2;
}

inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) {
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}});
inline std::shared_ptr<Node> GenSlice(detail::PatternNode data,
Symbol start,
Symbol stop,
Symbol step,
size_t axis,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}},
{},
nullptr,
nullptr,
line_no,
file);

std::vector<Symbol> vbegin(axis + 1, Symbol(0));
std::vector<Symbol> vend(axis + 1, Symbol(0));
@@ -1168,7 +1222,11 @@ inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Sy
{"end_mask", end_mask},
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
{"ellipsis_mask", ellipsis_mask}},
nullptr,
nullptr,
line_no,
file);
return opt1 | opt2;
}

@@ -1329,7 +1387,9 @@ class PatternValidator {
auto id = sym.get_id();
if (symbol_value_map.count(id)) {
if (symbol_value_map[id] != value) {
_VERBOSE_LOG(" in-consistency between multiple references of same symbol : ",
_VERBOSE_LOG(" in-consistency between multiple references of same symbol(",
sym.get_name(),
"): ",
symbol_value_map[id],
" != ",
value);
@@ -1345,7 +1405,12 @@ class PatternValidator {
if (sym.is_literal_const()) {
auto literal = sym.eval(symbol_value_map);
if (literal != value) {
_VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value);
_VERBOSE_LOG(" mismatch between literal symbol & value(",
sym.get_name(),
"): ",
literal,
" != ",
value);
return false;
}
// no need to put literal into value map to eval them.
@@ -1373,7 +1438,9 @@ class PatternValidator {
}
}
if (!is_match) {
_VERBOSE_LOG(" mismatch between derived & value : ",
_VERBOSE_LOG(" mismatch between derived & value(",
sym.get_name(),
"): ",
std::setprecision(std::numeric_limits<float>::max_digits10),
derived,
" != ",
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/gen_pattern.hpp>
@@ -20,7 +21,12 @@
#include "itt.hpp"
#include "openvino/opsets/opset1.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp"
#include "transformations/defs.hpp"
#include "transformations/op_conversions/convert_broadcast3.hpp"
#include "transformations/transpose_sinking/ts_shape_of.hpp"
using namespace ov::gen_pattern;

namespace ov {
@@ -56,8 +62,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
std::shared_ptr<Node> reshape_k, reshape_v, unsqueeze_k, unsqueeze_v;
std::shared_ptr<Node> computed_bcst_k, computed_bcst_v, multiply_k, multiply_v;
std::shared_ptr<Node> mq_reshape_k, mq_reshape_v;
std::shared_ptr<Node> computed_bcst3_k, computed_bcst3_v;
auto multi_query_bcst = [](const std::shared_ptr<Node>& kv) {
auto reshape_kv = wrap_type<opset6::Reshape>({kv, any_input()});
auto reshape_kv = makePattern<opset6::Reshape>({kv, any_input()});
auto unsqueeze_kv = makePattern<opset1::Unsqueeze>({kv, any_input()});

auto check_one = [](Output<Node> output) -> bool {
@@ -73,13 +80,17 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
makePattern<opset1::Broadcast>({wrap_type<opset1::Constant>(check_one), any_input(), any_input()},
{{"mode", "numpy"}});

auto multiply_kv = wrap_type<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
auto result = wrap_type<opset6::Reshape>({multiply_kv, any_input()});
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv);
auto multiply_kv = makePattern<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
auto computed_bcst3 = makePattern<opset3::Broadcast>({unsqueeze_kv, any_input()}, {{"mode", "bidirectional"}});

auto result = makePattern<opset6::Reshape>({multiply_kv | computed_bcst3, any_input()});
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv, computed_bcst3);
};

std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k) = multi_query_bcst(concat_k);
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v) = multi_query_bcst(concat_v);
std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) =
multi_query_bcst(concat_k);
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v, computed_bcst3_v) =
multi_query_bcst(concat_v);
auto present_k = concat_k | mq_reshape_k;
auto present_v = concat_v | mq_reshape_v;

@@ -178,15 +189,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() {

opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) {
return false;
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
}
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) {
return false;
}

if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node))
if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node)) {
return false;
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
}
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) {
return false;
}

auto is_optional_one_child = [&pattern_map](const std::vector<std::shared_ptr<Node>>& nodes) {
for (auto&& node : nodes) {
@@ -212,7 +227,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
computed_bcst_v,
multiply_v,
mq_reshape_k,
mq_reshape_v})) {
mq_reshape_v,
computed_bcst3_k,
computed_bcst3_v})) {
return false;
}

@@ -284,5 +301,17 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
this->register_matcher(m, callback);
}

bool SDPASubgraphFusion::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion);
ov::pass::Manager manager("SDPASubgraphFusion");

CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);

manager.run_passes(f);
return false;
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -14,5 +14,12 @@ class StatefulSDPAFusion : public ov::pass::MatcherPass {
StatefulSDPAFusion();
};

class SDPASubgraphFusion : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("SDPASubgraphFusion", "0");

bool run_on_model(const std::shared_ptr<ov::Model>& f) override;
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -431,6 +431,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
ov::pass::KeepConstAndDecompression);

CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
CPU_REGISTER_PASS_COMMON(manager, SDPASubgraphFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
CPU_REGISTER_PASS_X64(manager, ov::pass::KeepConstsPrecision, decompression_precisions, false, true);
CPU_SET_CALLBACK_X64(
@@ -654,16 +655,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertNMS9ToNMSIEInternal);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMulticlassNmsToMulticlassNmsIE);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMatrixNmsToMatrixNmsIE);
CPU_SET_CALLBACK_COMMON(
manager,
[this](const_node_ptr& node) -> bool {
std::string errorMsg;
// Current SDPA impl is optimized only for LLM models, so we decompose it for others to avoid perf
// regression. Matching the pattern is a little complicated, so we just check if there is any state nodes.
return node::ScaledDotProductAttention::isSupportedOperation(node, errorMsg) &&
model->get_variables().size() > 0;
},
ov::pass::ScaledDotProductAttentionDecomposition);

// List of enabled/disabled transformations

@@ -945,9 +936,6 @@ void Transformations::PostLpt() {
}
#endif // OPENVINO_ARCH_X86_64

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm);
CPU_SET_CALLBACK_X64(
Original file line number Diff line number Diff line change
@@ -152,9 +152,9 @@ class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQu
auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsquezeAxis);
auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsquezeAxis);

auto targetShape = ov::op::v0::Constant::create(qkvType, {1, 1, 1, 4, 1}, {1});
auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);
auto targetShape = ov::op::v0::Constant::create(element::i32, {5}, {1, 1, 1, 4, 1});
auto broadcastK = std::make_shared<ov::op::v3::Broadcast>(unsqueezeK, targetShape, op::BroadcastType::BIDIRECTIONAL);
auto broadcastV = std::make_shared<ov::op::v3::Broadcast>(unsqueezeV, targetShape, op::BroadcastType::BIDIRECTIONAL);

auto target4D = ov::op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64});

Loading