@@ -113,13 +113,24 @@ class BufferAllocationCPUTest : public testing::TestWithParam<BufferAllocationCP
113
113
114
114
virtual std::shared_ptr<ov::Model> GetModel (const std::vector<ov::PartialShape>& shapes) const = 0;
115
115
116
- void MarkOp (const std::shared_ptr<ov::Node>& node, const std::vector<size_t >& subtensor) const {
117
- for (const auto & input : node->inputs ())
116
+ void MarkOp (const std::shared_ptr<ov::Node>& node,
117
+ const std::vector<std::vector<size_t >>& in_subtensors,
118
+ const std::vector<std::vector<size_t >>& out_subtensors) const {
119
+ OPENVINO_ASSERT (in_subtensors.size () == node->inputs ().size (), " Incorrect count of input subtensors" );
120
+ OPENVINO_ASSERT (out_subtensors.size () == node->outputs ().size (), " Incorrect count of output subtensors" );
121
+ // Mark input and output ports with the first supported subtensor
122
+ for (size_t i = 0 ; i < node->inputs ().size (); ++i) {
123
+ const auto & input = node->input (i);
118
124
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr (
119
- input, std::make_shared<ov::snippets::lowered::PortDescriptor>(input, subtensor));
120
- for (const auto & output : node->outputs ())
125
+ input,
126
+ std::make_shared<ov::snippets::lowered::PortDescriptor>(input, in_subtensors[i]));
127
+ }
128
+ for (size_t i = 0 ; i < node->outputs ().size (); ++i) {
129
+ const auto & output = node->output (i);
121
130
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr (
122
- output, std::make_shared<ov::snippets::lowered::PortDescriptor>(output, subtensor));
131
+ output,
132
+ std::make_shared<ov::snippets::lowered::PortDescriptor>(output, out_subtensors[i]));
133
+ }
123
134
}
124
135
125
136
ov::snippets::lowered::LinearIR m_linear_ir;
@@ -173,12 +184,12 @@ class MHAFP32BufferAllocationTest : public BufferAllocationCPUTest {
173
184
174
185
const auto body = std::make_shared<ov::Model>(std::make_shared<ov::op::v0::Result>(relu2), ov::ParameterVector{parameter0, parameter1, parameter2});
175
186
176
- MarkOp (load_reshape, subtensor_scalar);
177
- MarkOp (store, subtensor_scalar);
178
- MarkOp (power, subtensor_power);
187
+ MarkOp (load_reshape, { subtensor_scalar}, {subtensor_scalar} );
188
+ MarkOp (store, { subtensor_scalar}, {subtensor_scalar} );
189
+ MarkOp (power, { subtensor_power}, {subtensor_power} );
179
190
180
- MarkOp (brgemm_cpu0, subtensor_full);
181
- MarkOp (brgemm_cpu1, subtensor_full);
191
+ MarkOp (brgemm_cpu0, { subtensor_full, subtensor_full}, {subtensor_full} );
192
+ MarkOp (brgemm_cpu1, { subtensor_full, subtensor_full}, {subtensor_full} );
182
193
183
194
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr (load_reshape->input (0 ))->set_layout (order);
184
195
@@ -192,6 +203,7 @@ class MHABF16AMXBufferAllocationTest : public BufferAllocationCPUTest {
192
203
const auto subtensor_scalar = std::vector<size_t >{1 };
193
204
const auto subtensor_power = std::vector<size_t >{1 , ov::snippets::utils::get_full_dim_value ()};
194
205
const auto subtensor_full = std::vector<size_t >(2 , ov::snippets::utils::get_full_dim_value ());
206
+ const auto subtensor_flat = std::vector<size_t >(1 , ov::snippets::utils::get_full_dim_value ());
195
207
196
208
OPENVINO_ASSERT (shapes.size () == 3 , " Incorrect count of input shapes" );
197
209
const auto parameter0 = std::make_shared<ov::op::v0::Parameter>(ov::element::bf16, shapes[0 ]);
@@ -234,16 +246,16 @@ class MHABF16AMXBufferAllocationTest : public BufferAllocationCPUTest {
234
246
235
247
const auto body = std::make_shared<ov::Model>(std::make_shared<ov::op::v0::Result>(relu2), ov::ParameterVector{parameter0, parameter1, parameter2});
236
248
237
- MarkOp (load_reshape, subtensor_scalar);
238
- MarkOp (store, subtensor_scalar);
239
- MarkOp (power, subtensor_power);
249
+ MarkOp (load_reshape, { subtensor_scalar}, {subtensor_scalar} );
250
+ MarkOp (store, { subtensor_scalar}, {subtensor_scalar} );
251
+ MarkOp (power, { subtensor_power}, {subtensor_power} );
240
252
241
- MarkOp (brgemm_cpu0, subtensor_full);
242
- MarkOp (brgemm_cpu1, subtensor_full);
243
- MarkOp (brgemm_copyb0, subtensor_full);
244
- MarkOp (brgemm_copyb1, subtensor_full);
245
- MarkOp (scratch0, subtensor_full );
246
- MarkOp (scratch1, subtensor_full );
253
+ MarkOp (brgemm_cpu0, { subtensor_full, subtensor_full, subtensor_flat}, {subtensor_full} );
254
+ MarkOp (brgemm_cpu1, { subtensor_full, subtensor_full, subtensor_flat}, {subtensor_full} );
255
+ MarkOp (brgemm_copyb0, {subtensor_flat}, { subtensor_full} );
256
+ MarkOp (brgemm_copyb1, {subtensor_flat}, { subtensor_full} );
257
+ MarkOp (scratch0, {}, {subtensor_flat} );
258
+ MarkOp (scratch1, {}, {subtensor_flat} );
247
259
248
260
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr (load_reshape->input (0 ))->set_layout (order);
249
261
0 commit comments