8
8
#include " snippets/utils/utils.hpp"
9
9
10
10
namespace {
11
- size_t get_dim_M (const ov::Shape& shape) {
12
- return *(shape.rbegin () + 1 );
13
- }
14
11
bool is_prime_number (size_t value) {
15
12
if (ov::snippets::utils::one_of (value, 2lu, 3lu)) return true ;
16
13
if (value == 1 || value % 2 == 0 || value % 3 == 0 ) return false ;
@@ -28,6 +25,7 @@ namespace snippets {
28
25
namespace pass {
29
26
30
27
const size_t SplitDimensionM::min_kernel_m = 32 ;
28
+ const size_t SplitDimensionM::dim_M_index = 1 ;
31
29
32
30
bool SplitDimensionM::is_supported_matmul (const std::shared_ptr<const ov::Node>& node) {
33
31
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(node);
@@ -113,13 +111,15 @@ std::vector<size_t> SplitDimensionM::get_updated_order(const std::vector<size_t>
113
111
}
114
112
115
113
ov::snippets::VectorDims SplitDimensionM::reshape_m_dim (ov::snippets::VectorDims shape, size_t m_index, size_t batch_m_dim, size_t new_m_dim) {
114
+ OPENVINO_ASSERT (m_index < shape.size (), " Incorrect M index: it should be less than target shape rank" );
116
115
if (shape[m_index] == 1 )
117
116
return unsqueeze_m_dim (std::move (shape), m_index);
118
117
shape[m_index] = new_m_dim;
119
118
shape.insert (shape.begin () + m_index, batch_m_dim);
120
119
return shape;
121
120
}
122
121
ov::snippets::VectorDims SplitDimensionM::unsqueeze_m_dim (ov::snippets::VectorDims shape, size_t m_index) {
122
+ OPENVINO_ASSERT (m_index < shape.size (), " Incorrect M index: it should be less than target shape rank" );
123
123
shape.insert (shape.begin () + m_index, 1 );
124
124
return shape;
125
125
}
@@ -194,6 +194,7 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
194
194
};
195
195
196
196
auto get_updated_shape = [&](const ov::snippets::VectorDims& shape, size_t m_index, bool split_m_dim) {
197
+ OPENVINO_ASSERT (m_index < shape.size (), " Dimension index must be less than shape rank" );
197
198
const auto current_m_dim = shape[m_index];
198
199
OPENVINO_ASSERT (!split_m_dim || current_m_dim == 1 || current_m_dim == m_dim, " Incorrect shape for splitting!" );
199
200
const auto new_shape = split_m_dim ? reshape_m_dim (shape, m_index, batch_m_dim, new_m_dim) : unsqueeze_m_dim (shape, m_index);
@@ -205,7 +206,8 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
205
206
const auto order_constant = ov::as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr (1 ));
206
207
OPENVINO_ASSERT (order_constant != nullptr , " Transpose must have Constant order" );
207
208
const auto order = order_constant->cast_vector <size_t >();
208
- const auto m_index = is_input ? order[order.size () - 2 ] : order.size () - 2 ; // Index of M dimension in the previous order
209
+ const auto forward_index = order.size () - 1 - dim_M_index;
210
+ const auto m_index = is_input ? order[forward_index] : forward_index; // Index of M dimension in the previous order
209
211
const auto new_order = get_updated_order (order, m_index);
210
212
transpose->set_argument (1 , std::make_shared<ov::op::v0::Constant>(order_constant->get_element_type (), ov::Shape{new_order.size ()}, new_order));
211
213
return m_index;
@@ -217,9 +219,13 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
217
219
return ;
218
220
219
221
const auto shape = param->get_partial_shape ().get_shape ();
222
+ // if the index of dimension M is equal or greater than Shape rank, no need to reshape it.
223
+ if (shape.size () <= dim_M_index)
224
+ return ;
225
+
220
226
const auto consumers = param->get_output_target_inputs (0 );
221
227
const auto shared_consumer = consumers.begin ()->get_node ()->shared_from_this ();
222
- auto m_index = shape.size () - 2 ;
228
+ auto m_index = shape.size () - 1 - dim_M_index ;
223
229
if (ov::is_type<ov::op::v1::Transpose>(shared_consumer)) {
224
230
m_index = reshape_transpose (shared_consumer, true );
225
231
}
0 commit comments