Skip to content

Commit c4776cc

Browse files
authored
[Snippets][CPU] Fix full dim subtensor setting for port descriptor (#29457)
### Details:
- Fix the full-dim subtensor value that is set in the BrgemmToBrgemmCPU pass for the 1D scenario
- Add a check that the subtensor shape is less than or equal to the tensor shape

### Tickets:
- 163738
1 parent 3b5c618 commit c4776cc

File tree

7 files changed

+79
-29
lines changed

7 files changed

+79
-29
lines changed

src/common/snippets/include/snippets/lowered/port_descriptor.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class PortDescriptor {
4242
const Reg& get_reg() const { return m_reg; }
4343

4444
void set_shape(const VectorDims& tensor);
45-
void set_layout(const std::vector<size_t>& layout) { m_layout = layout; }
46-
void set_subtensor(const VectorDims& subtensor) { m_subtensor_shape = subtensor; }
45+
void set_layout(const std::vector<size_t>& layout);
46+
void set_subtensor(const VectorDims& subtensor);
4747
void set_reg(Reg reg) { m_reg = std::move(reg); }
4848
void set_reg_type(RegType type) { m_reg.type = type; }
4949
void set_reg_idx(size_t idx) { m_reg.idx = idx; }

src/common/snippets/src/lowered/pass/propagate_subtensors.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,12 @@ void propagate_updated_subtensor_through_loop(const LinearIR& linear_ir,
153153
// After subtensor propagation, the original shapes must be restored
154154
for (const auto& elem : original_shapes)
155155
elem.first->set_shape(elem.second);
156-
for (auto expr_it = begin; expr_it != shape_inference_end_it; expr_it++)
157-
(*expr_it)->updateShapes();
156+
for (auto expr_it = begin; expr_it != shape_inference_end_it; expr_it++) {
157+
const auto expr = *expr_it;
158+
if (ov::is_type<snippets::op::LoopBase>(expr->get_node()))
159+
continue;
160+
expr->updateShapes();
161+
}
158162
}
159163
} // namespace
160164

src/common/snippets/src/lowered/port_descriptor.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ void PortDescriptor::validate_arguments() {
3838
// NCHW layout by default
3939
std::iota(m_layout.begin(), m_layout.end(), 0);
4040
}
41+
OPENVINO_ASSERT(m_subtensor_shape.size() <= m_tensor_shape->size(),
42+
"Snippets tensor descriptor: Subtensor shape must be less than or equal to tensor shape");
4143
OPENVINO_ASSERT(m_layout.size() == m_tensor_shape->size(), "Snippets tensor descriptor: Layout size must be equal to the shape size");
4244
}
4345

@@ -48,9 +50,22 @@ const VectorDims& PortDescriptor::get_shape() const {
4850

4951
void PortDescriptor::set_shape(const VectorDims& tensor) {
5052
OPENVINO_ASSERT(m_tensor_shape, "Failed to set_shape: Tensor Shape is nullptr");
53+
OPENVINO_ASSERT(m_subtensor_shape.size() <= tensor.size(),
54+
"Snippets tensor descriptor: Subtensor shape must be less than or equal to tensor shape");
5155
*m_tensor_shape = tensor;
5256
}
5357

58+
void PortDescriptor::set_layout(const std::vector<size_t>& layout) {
59+
OPENVINO_ASSERT(layout.size() == m_tensor_shape->size(),
60+
"Snippets tensor descriptor: Layout size must be equal to the shape size");
61+
m_layout = layout;
62+
}
63+
void PortDescriptor::set_subtensor(const VectorDims& subtensor) {
64+
OPENVINO_ASSERT(subtensor.size() <= m_tensor_shape->size(),
65+
"Subtensor shape must be less than or equal to tensor shape");
66+
m_subtensor_shape = subtensor;
67+
}
68+
5469
void PortDescriptor::set_subtensor_dim(size_t idx, VectorDims::value_type value) {
5570
OPENVINO_ASSERT(idx < m_subtensor_shape.size(), "Failed to set subtensor value: idx should be less than size");
5671
*(m_subtensor_shape.rbegin() + idx) = value;

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ namespace {
2727
template <typename T>
2828
void set_full_port_desc(const T& port) {
2929
const auto& shape_rank = port.get_partial_shape().size();
30-
static const std::vector<size_t> full_dim_subtensor(std::min(shape_rank, static_cast<size_t>(2)),
31-
ov::snippets::utils::get_full_dim_value());
30+
const std::vector<size_t> full_dim_subtensor(std::min(shape_rank, static_cast<size_t>(2)),
31+
ov::snippets::utils::get_full_dim_value());
3232
PortDescriptorUtils::set_port_descriptor(port, full_dim_subtensor);
3333
}
3434
} // namespace

src/plugins/intel_cpu/src/transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,18 @@ EltwiseToEltwiseTPP::EltwiseToEltwiseTPP() {
4040
ov::is_type<ov::snippets::op::ReduceBase>(node) ? ov::snippets::utils::get_full_dim_value() : 64;
4141
ov::replace_node_update_name(node, tpp_eltwise);
4242
for (size_t i = 0; i < node->get_input_size(); i++) {
43-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->input(i), {M_block, N_block});
43+
auto subtensor = snippets::VectorDims{M_block, N_block};
44+
if (tpp_eltwise->get_input_partial_shape(i).size() < 2) {
45+
subtensor = snippets::VectorDims{N_block};
46+
}
47+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->input(i), subtensor);
4448
}
4549

46-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->output(0), {M_block, N_block});
50+
auto subtensor = snippets::VectorDims{M_block, N_block};
51+
if (tpp_eltwise->output(0).get_partial_shape().size() < 2) {
52+
subtensor = snippets::VectorDims{N_block};
53+
}
54+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->output(0), subtensor);
4755

4856
return true;
4957
};

src/plugins/intel_cpu/src/transformations/tpp/x64/pass/fuse_tpp_to_equations.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,20 @@ bool FuseTPPToEquations::fuse_from_root(const NodePtr& root, const std::shared_p
8282
kv.second = equation;
8383
replace_nodes(m, {}, node_replace_map);
8484
for (const auto& in : equation->inputs()) {
85-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(in, root_subtensor);
85+
auto subtensor = root_subtensor;
86+
if (in.get_partial_shape().size() < root_subtensor.size()) {
87+
subtensor.erase(subtensor.begin(),
88+
subtensor.begin() + (root_subtensor.size() - in.get_partial_shape().size()));
89+
}
90+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(in, subtensor);
91+
}
92+
auto subtensor = root_subtensor;
93+
const auto& out = equation->output(0);
94+
if (out.get_partial_shape().size() < root_subtensor.size()) {
95+
subtensor.erase(subtensor.begin(),
96+
subtensor.begin() + (root_subtensor.size() - out.get_partial_shape().size()));
8697
}
87-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(equation->output(0), root_subtensor);
98+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(equation->output(0), subtensor);
8899
return true;
89100
}
90101

src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/buffer_allocation.cpp

+31-19
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,24 @@ class BufferAllocationCPUTest : public testing::TestWithParam<BufferAllocationCP
113113

114114
virtual std::shared_ptr<ov::Model> GetModel(const std::vector<ov::PartialShape>& shapes) const = 0;
115115

116-
void MarkOp(const std::shared_ptr<ov::Node>& node, const std::vector<size_t>& subtensor) const {
117-
for (const auto& input : node->inputs())
116+
void MarkOp(const std::shared_ptr<ov::Node>& node,
117+
const std::vector<std::vector<size_t>>& in_subtensors,
118+
const std::vector<std::vector<size_t>>& out_subtensors) const {
119+
OPENVINO_ASSERT(in_subtensors.size() == node->inputs().size(), "Incorrect count of input subtensors");
120+
OPENVINO_ASSERT(out_subtensors.size() == node->outputs().size(), "Incorrect count of output subtensors");
121+
// Mark input and output ports with the first supported subtensor
122+
for (size_t i = 0; i < node->inputs().size(); ++i) {
123+
const auto& input = node->input(i);
118124
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(
119-
input, std::make_shared<ov::snippets::lowered::PortDescriptor>(input, subtensor));
120-
for (const auto& output : node->outputs())
125+
input,
126+
std::make_shared<ov::snippets::lowered::PortDescriptor>(input, in_subtensors[i]));
127+
}
128+
for (size_t i = 0; i < node->outputs().size(); ++i) {
129+
const auto& output = node->output(i);
121130
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(
122-
output, std::make_shared<ov::snippets::lowered::PortDescriptor>(output, subtensor));
131+
output,
132+
std::make_shared<ov::snippets::lowered::PortDescriptor>(output, out_subtensors[i]));
133+
}
123134
}
124135

125136
ov::snippets::lowered::LinearIR m_linear_ir;
@@ -173,12 +184,12 @@ class MHAFP32BufferAllocationTest : public BufferAllocationCPUTest {
173184

174185
const auto body = std::make_shared<ov::Model>(std::make_shared<ov::op::v0::Result>(relu2), ov::ParameterVector{parameter0, parameter1, parameter2});
175186

176-
MarkOp(load_reshape, subtensor_scalar);
177-
MarkOp(store, subtensor_scalar);
178-
MarkOp(power, subtensor_power);
187+
MarkOp(load_reshape, {subtensor_scalar}, {subtensor_scalar});
188+
MarkOp(store, {subtensor_scalar}, {subtensor_scalar});
189+
MarkOp(power, {subtensor_power}, {subtensor_power});
179190

180-
MarkOp(brgemm_cpu0, subtensor_full);
181-
MarkOp(brgemm_cpu1, subtensor_full);
191+
MarkOp(brgemm_cpu0, {subtensor_full, subtensor_full}, {subtensor_full});
192+
MarkOp(brgemm_cpu1, {subtensor_full, subtensor_full}, {subtensor_full});
182193

183194
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(load_reshape->input(0))->set_layout(order);
184195

@@ -192,6 +203,7 @@ class MHABF16AMXBufferAllocationTest : public BufferAllocationCPUTest {
192203
const auto subtensor_scalar = std::vector<size_t>{1};
193204
const auto subtensor_power = std::vector<size_t>{1, ov::snippets::utils::get_full_dim_value()};
194205
const auto subtensor_full = std::vector<size_t>(2, ov::snippets::utils::get_full_dim_value());
206+
const auto subtensor_flat = std::vector<size_t>(1, ov::snippets::utils::get_full_dim_value());
195207

196208
OPENVINO_ASSERT(shapes.size() == 3, "Incorrect count of input shapes");
197209
const auto parameter0 = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, shapes[0]);
@@ -234,16 +246,16 @@ class MHABF16AMXBufferAllocationTest : public BufferAllocationCPUTest {
234246

235247
const auto body = std::make_shared<ov::Model>(std::make_shared<ov::op::v0::Result>(relu2), ov::ParameterVector{parameter0, parameter1, parameter2});
236248

237-
MarkOp(load_reshape, subtensor_scalar);
238-
MarkOp(store, subtensor_scalar);
239-
MarkOp(power, subtensor_power);
249+
MarkOp(load_reshape, {subtensor_scalar}, {subtensor_scalar});
250+
MarkOp(store, {subtensor_scalar}, {subtensor_scalar});
251+
MarkOp(power, {subtensor_power}, {subtensor_power});
240252

241-
MarkOp(brgemm_cpu0, subtensor_full);
242-
MarkOp(brgemm_cpu1, subtensor_full);
243-
MarkOp(brgemm_copyb0, subtensor_full);
244-
MarkOp(brgemm_copyb1, subtensor_full);
245-
MarkOp(scratch0, subtensor_full);
246-
MarkOp(scratch1, subtensor_full);
253+
MarkOp(brgemm_cpu0, {subtensor_full, subtensor_full, subtensor_flat}, {subtensor_full});
254+
MarkOp(brgemm_cpu1, {subtensor_full, subtensor_full, subtensor_flat}, {subtensor_full});
255+
MarkOp(brgemm_copyb0, {subtensor_flat}, {subtensor_full});
256+
MarkOp(brgemm_copyb1, {subtensor_flat}, {subtensor_full});
257+
MarkOp(scratch0, {}, {subtensor_flat});
258+
MarkOp(scratch1, {}, {subtensor_flat});
247259

248260
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(load_reshape->input(0))->set_layout(order);
249261

0 commit comments

Comments
 (0)