Skip to content

Commit 0d37a89

Browse files
authored
[Snippets] Fixed handling of Parameters with 1D shapes in SplitDimensionM pass (#29573)
### Details: - *Fixed handling of `Parameters` with 1D shapes in `SplitDimensionM` pass. No need to insert `Reshape` op on input with 1D shape since `M` dimension has index `1` from the end.* ### Tickets: - *149846*
1 parent 7c69a5e commit 0d37a89

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

src/common/snippets/include/snippets/pass/split_dimension_m.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,14 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {
8282

8383
void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);
8484

85+
static size_t get_dim_M(const ov::Shape& shape) {
86+
return *(shape.rbegin() + dim_M_index);
87+
}
88+
8589
size_t m_concurrency;
8690

8791
static const size_t min_kernel_m;
92+
static const size_t dim_M_index;
8893
};
8994
} // namespace pass
9095
} // namespace snippets

src/common/snippets/src/pass/split_dimension_m.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
#include "snippets/utils/utils.hpp"
99

1010
namespace {
11-
size_t get_dim_M(const ov::Shape& shape) {
12-
return *(shape.rbegin() + 1);
13-
}
1411
bool is_prime_number(size_t value) {
1512
if (ov::snippets::utils::one_of(value, 2lu, 3lu)) return true;
1613
if (value == 1 || value % 2 == 0 || value % 3 == 0) return false;
@@ -28,6 +25,7 @@ namespace snippets {
2825
namespace pass {
2926

3027
const size_t SplitDimensionM::min_kernel_m = 32;
28+
const size_t SplitDimensionM::dim_M_index = 1;
3129

3230
bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>& node) {
3331
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(node);
@@ -113,13 +111,15 @@ std::vector<size_t> SplitDimensionM::get_updated_order(const std::vector<size_t>
113111
}
114112

115113
ov::snippets::VectorDims SplitDimensionM::reshape_m_dim(ov::snippets::VectorDims shape, size_t m_index, size_t batch_m_dim, size_t new_m_dim) {
114+
OPENVINO_ASSERT(m_index < shape.size(), "Incorrect M index: it should be less than target shape rank");
116115
if (shape[m_index] == 1)
117116
return unsqueeze_m_dim(std::move(shape), m_index);
118117
shape[m_index] = new_m_dim;
119118
shape.insert(shape.begin() + m_index, batch_m_dim);
120119
return shape;
121120
}
122121
ov::snippets::VectorDims SplitDimensionM::unsqueeze_m_dim(ov::snippets::VectorDims shape, size_t m_index) {
122+
OPENVINO_ASSERT(m_index < shape.size(), "Incorrect M index: it should be less than target shape rank");
123123
shape.insert(shape.begin() + m_index, 1);
124124
return shape;
125125
}
@@ -194,6 +194,7 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
194194
};
195195

196196
auto get_updated_shape = [&](const ov::snippets::VectorDims& shape, size_t m_index, bool split_m_dim) {
197+
OPENVINO_ASSERT(m_index < shape.size(), "Dimension index must be less than shape rank");
197198
const auto current_m_dim = shape[m_index];
198199
OPENVINO_ASSERT(!split_m_dim || current_m_dim == 1 || current_m_dim == m_dim, "Incorrect shape for splitting!");
199200
const auto new_shape = split_m_dim ? reshape_m_dim(shape, m_index, batch_m_dim, new_m_dim) : unsqueeze_m_dim(shape, m_index);
@@ -205,7 +206,8 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
205206
const auto order_constant = ov::as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
206207
OPENVINO_ASSERT(order_constant != nullptr, "Transpose must have Constant order");
207208
const auto order = order_constant->cast_vector<size_t>();
208-
const auto m_index = is_input ? order[order.size() - 2] : order.size() - 2; // Index of M dimension in the previous order
209+
const auto forward_index = order.size() - 1 - dim_M_index;
210+
const auto m_index = is_input ? order[forward_index] : forward_index; // Index of M dimension in the previous order
209211
const auto new_order = get_updated_order(order, m_index);
210212
transpose->set_argument(1, std::make_shared<ov::op::v0::Constant>(order_constant->get_element_type(), ov::Shape{new_order.size()}, new_order));
211213
return m_index;
@@ -217,9 +219,13 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
217219
return;
218220

219221
const auto shape = param->get_partial_shape().get_shape();
222+
// if the index of dimension M is equal or greater than Shape rank, no need to reshape it.
223+
if (shape.size() <= dim_M_index)
224+
return;
225+
220226
const auto consumers = param->get_output_target_inputs(0);
221227
const auto shared_consumer = consumers.begin()->get_node()->shared_from_this();
222-
auto m_index = shape.size() - 2;
228+
auto m_index = shape.size() - 1 - dim_M_index;
223229
if (ov::is_type<ov::op::v1::Transpose>(shared_consumer)) {
224230
m_index = reshape_transpose(shared_consumer, true);
225231
}

src/common/snippets/tests/src/pass/mha_tokenization.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
240240
run();
241241
}
242242

243+
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM_ScalarParams) {
244+
const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1}, {64}, {8, 64, 512}},
245+
std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {}, {},
246+
{8, 1, 64, 512}, {8, 512, 512}});
247+
model = f.getOriginal();
248+
model_ref = f.getReference();
249+
config.set_concurrency(16);
250+
run();
251+
}
252+
243253
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Reshape_extraction) {
244254
const auto& f = MHAWithExtractedReshapeFunction(std::vector<PartialShape>{{400, 196, 80},
245255
{400, 80, 196},

src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,10 @@ std::shared_ptr<ov::Model> MHASelectSplitMFunction::initReference() const {
471471
auto param2 = std::make_shared<ov::opset1::Parameter>(precision, input_shapes[4]);
472472
ov::ParameterVector ngraphParam = {param0, param1, addParam, selectParam, param2};
473473

474-
auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) {
474+
auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) -> std::shared_ptr<ov::Node> {
475+
if (new_shape.empty()) {
476+
return node;
477+
}
475478
auto shape_const = ov::op::v0::Constant::create(ov::element::i32, {new_shape.size()}, new_shape);
476479
return std::make_shared<ov::op::v1::Reshape>(node, shape_const, true);
477480
};

0 commit comments

Comments
 (0)