Skip to content

Commit abc20bd

Browse files
committed
build brgemm
1 parent c27f796 commit abc20bd

35 files changed

+862
-238
lines changed

src/common/snippets/docs/mha_optimization_guide.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ For enhancing the execution efficiency, blocking across the M, K, and N matmul d
123123

124124
### Blocking Parameters
125125

126-
The heuristics for determining the optimal block sizes can be found in [BrgemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
126+
The heuristics for determining the optimal block sizes can be found in [GemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
127127

128128
**Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**
129129

@@ -141,7 +141,7 @@ Based on previously discussed information, we provide the following recommendati
141141
In local experiments, some transformations might be worth to change:
142142
- Disable [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark Snippets Transpose implementation.
143143
- Adjust [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass at all.
144-
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `BrgemmCPUBlocking`.
144+
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `GemmCPUBlocking`.
145145
- Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have his own optimal K, N blocking params.
146146
M block is better to keep the same since the corresponding blocking loop is shared between both Matmuls.
147147
- For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for Brgemm node, or for BrgemmCopyB repacking too.

src/common/snippets/src/lowered/linear_ir.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,16 @@ LinearIR::exprIt LinearIR::replace_with_expr(const std::vector<ExpressionPtr>& o
432432
const auto input_ports = new_expr_it->get()->get_input_ports();
433433
const auto output_ports = new_expr_it->get()->get_output_ports();
434434
for (const auto& old_expr : old_exprs) {
435-
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
436-
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_input_port(i), input_ports);
437-
for (size_t i = 0; i < old_expr->get_input_count(); ++i)
438-
m_loop_manager->replace_loop_ports(loop_ids, old_expr->get_output_port(i), output_ports);
435+
for (size_t i = 0; i < old_expr->get_input_count(); ++i) {
436+
m_loop_manager->replace_loop_ports(loop_ids,
437+
old_expr->get_input_port(i),
438+
{new_expr_it->get()->get_input_port(i)});
439+
}
440+
for (size_t i = 0; i < old_expr->get_output_count(); ++i) {
441+
m_loop_manager->replace_loop_ports(loop_ids,
442+
old_expr->get_output_port(i),
443+
{new_expr_it->get()->get_output_port(i)});
444+
}
439445
erase(find(old_expr));
440446
}
441447
return new_expr_it;

src/common/snippets/src/op/brgemm.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
5050
}
5151

5252
void Brgemm::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) {
53-
INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
53+
INTERNAL_OP_SCOPE(GemmCPU_constructor_validate_and_infer_types);
5454

5555
// During ctor call, Brgemm doesn't know his port descriptors.
5656
// So we use explicit layouts from parameters
@@ -100,7 +100,7 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
100100
ov::element::Type Brgemm::get_output_type() const {
101101
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
102102
if (output_type == element::undefined) {
103-
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
103+
OPENVINO_THROW("GemmCPU node has incompatible input element types: " +
104104
get_input_element_type(0).get_type_name() +
105105
" and " +
106106
get_input_element_type(1).get_type_name());

src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "transformations/snippets/common/op/fused_mul_add.hpp"
2525
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
2626
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
27+
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
2728
#include "transformations/snippets/x64/op/load_convert.hpp"
2829
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
2930
#include "transformations/snippets/x64/op/store_convert.hpp"
@@ -260,6 +261,10 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
260261

261262
// Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes
262263
// additional arguments
264+
jitters[intel_cpu::GemmCPU::get_type_info_static()] =
265+
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
266+
configurator->get_kernel_executor_table(),
267+
compiled_kernel_cache);
263268
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] =
264269
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
265270
configurator->get_kernel_executor_table(),
@@ -431,7 +436,7 @@ std::shared_ptr<snippets::Generator> intel_cpu::CPUGenerator::clone() const {
431436

432437
ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(const ov::Output<ov::Node>& out) const {
433438
const auto op = out.get_node_shared_ptr();
434-
if (is_type<intel_cpu::BrgemmCPU>(op) ||
439+
if (is_type<intel_cpu::GemmCPU>(op) ||
435440
#ifdef SNIPPETS_LIBXSMM_TPP
436441
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
437442
is_type<intel_cpu::tpp::op::Scalar>(op) ||

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include "emitters/plugin/x64/utils.hpp"
1111
#include "emitters/snippets/x64/utils.hpp"
1212
#include "snippets/utils/utils.hpp"
13-
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
13+
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
1414

1515
using namespace Xbyak;
1616
using namespace dnnl::impl;

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp

+59-32
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
#include "emitters/plugin/x64/utils.hpp"
88
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
99
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
10+
#include "emitters/snippets/x64/kernel_executors/brgemm_batched.hpp"
1011
#include "snippets/utils/utils.hpp"
1112
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
13+
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
1214
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
1315
#include "utils.hpp"
1416

@@ -26,44 +28,67 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
2628
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
2729
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
2830
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
29-
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
30-
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
31-
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
32-
const auto brgemm_type = brgemm_node->get_type();
33-
m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
34-
if (m_is_with_amx) {
35-
BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
31+
if (is_type<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
32+
const auto& gemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
33+
const auto& brg0Prc = gemm_node->get_input_element_type(0);
34+
const auto& brg1Prc = gemm_node->get_input_element_type(1);
35+
const auto brgemm_type = gemm_node->get_type();
36+
m_is_with_amx = false;
37+
38+
BrgemmBatchedKernelConfig kernel_config(brg0Prc,
39+
brg1Prc,
40+
with_compensations(brgemm_type),
41+
brgemm_utils::get_primitive_isa(brg0Prc, false));
3642
m_kernel_executor =
37-
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
43+
kernel_table->register_kernel<BrgemmBatchedKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
44+
45+
m_memory_offsets = {gemm_node->get_offset_a(), gemm_node->get_offset_b(), gemm_node->get_offset_c()};
46+
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
47+
utils::get_buffer_cluster_id(expr->get_input_port(1)),
48+
utils::get_buffer_cluster_id(expr->get_output_port(0))};
49+
} else if (is_type<ov::intel_cpu::GemmCPU>(expr->get_node())) {
50+
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
51+
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
52+
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
53+
const auto brgemm_type = brgemm_node->get_type();
54+
m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
55+
if (m_is_with_amx) {
56+
BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
57+
m_kernel_executor =
58+
kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
59+
} else {
60+
BrgemmKernelConfig kernel_config(brg0Prc,
61+
brg1Prc,
62+
with_compensations(brgemm_type),
63+
brgemm_utils::get_primitive_isa(brg0Prc, false));
64+
m_kernel_executor =
65+
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
66+
}
67+
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
68+
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
69+
// for both static and the 1st dynamic shapes.
70+
OV_CPU_JIT_EMITTER_ASSERT(
71+
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
72+
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
73+
"Jit emitter is called when the shapes are unknown");
74+
75+
m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
76+
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
77+
utils::get_buffer_cluster_id(expr->get_input_port(1)),
78+
utils::get_buffer_cluster_id(expr->get_output_port(0))};
79+
if (with_scratchpad(brgemm_type)) {
80+
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
81+
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
82+
}
3883
} else {
39-
BrgemmKernelConfig kernel_config(brg0Prc,
40-
brg1Prc,
41-
with_compensations(brgemm_type),
42-
brgemm_utils::get_primitive_isa(brg0Prc, false));
43-
m_kernel_executor =
44-
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
45-
}
46-
// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
47-
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
48-
// for both static and the 1st dynamic shapes.
49-
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
50-
!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
51-
"Jit emitter is called when the shapes are unknown");
52-
53-
m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
54-
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
55-
utils::get_buffer_cluster_id(expr->get_input_port(1)),
56-
utils::get_buffer_cluster_id(expr->get_output_port(0))};
57-
if (with_scratchpad(brgemm_type)) {
58-
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
59-
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
84+
OV_CPU_JIT_EMITTER_THROW("got unsupported node type");
6085
}
6186
}
6287

6388
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
6489
const std::shared_ptr<ov::Node>& node) {
65-
const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
66-
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
90+
const auto brgemm = as_type_ptr<ov::intel_cpu::GemmCPU>(node);
91+
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects GemmCPU node");
6792
using brgemm_utils::BRGEMM_TYPE;
6893
if (brgemm->get_type() == BRGEMM_TYPE::STAND_ALONE) {
6994
return {{element::f32, element::f32}};
@@ -83,7 +108,7 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
83108
{element::bf16, element::bf16, element::u8},
84109
{element::f16, element::f16, element::u8}};
85110
}
86-
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
111+
OV_CPU_JIT_EMITTER_THROW("got GemmCPU node with unsupported type");
87112
}
88113

89114
void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
@@ -103,6 +128,8 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
103128
emit_call<BrgemmAMXKernelExecutor>(mem_ptrs_idxs);
104129
} else if (std::dynamic_pointer_cast<BrgemmKernelExecutor>(m_kernel_executor)) {
105130
emit_call<BrgemmKernelExecutor>(mem_ptrs_idxs);
131+
} else if (std::dynamic_pointer_cast<BrgemmBatchedKernelExecutor>(m_kernel_executor)) {
132+
emit_call<BrgemmBatchedKernelExecutor>(mem_ptrs_idxs);
106133
} else {
107134
OV_CPU_JIT_EMITTER_THROW("uknown execuor type");
108135
}

src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "common/utils.hpp"
88
#include "dnnl_extension_utils.h"
99
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
10-
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
10+
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
1111
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
1212

1313
using namespace Xbyak;

src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include <cpu/x64/amx_tile_configure.hpp>
88

9-
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
9+
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
1010
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
1111

1212
#define INNER_K_BLK(dtype) static_cast<dnnl_dim_t>((brgemm_utils::repacking::compute_inner_k_block(in0_dtype)))
@@ -293,7 +293,7 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c
293293

294294
if (K_tail != 0) {
295295
if (config.need_copy_a(K_tail)) {
296-
auto* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE;
296+
auto* tr_src = scratch + GemmCPU::SCRATCH_BYTE_SIZE;
297297

298298
execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail);
299299
src_ptr = tr_src;

src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "common/utils.hpp"
88
#include "dnnl_extension_utils.h"
99
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
10+
#include "transformations/snippets/x64/op/gemm_cpu.hpp"
1011
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
1112

1213
#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
@@ -163,7 +164,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
163164
const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout());
164165
const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout());
165166
auto in0_subtensor = input_pds[0]->get_subtensor();
167+
OPENVINO_ASSERT(!in0_subtensor.empty(), "Incorrect in0 subtensor size");
166168
auto in1_subtensor = input_pds[1]->get_subtensor();
169+
OPENVINO_ASSERT(!in1_subtensor.empty(), "Incorrect in1 subtensor size");
167170

168171
// Need to update M, K, N
169172
// 1. If the original value in subtensor is `FULL_DIM`, it means that
@@ -254,11 +257,15 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
254257
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
255258
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));
256259

257-
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
258-
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
259-
// In case of data repacking LDB is chosen in accordance with repacking buffer size
260-
if (with_repacking(brgemm_node->get_type())) {
261-
LDB = DIM_CAST(brgemm_utils::repacking::compute_repacked_n_dim(LDB, brgemm_node->get_input_element_type(1)));
260+
if (is_type<ov::intel_cpu::BrgemmCPU>(expr->get_node())) {
261+
} else if (is_type<ov::intel_cpu::GemmCPU>(expr->get_node())) {
262+
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::GemmCPU>(expr->get_node());
263+
// In case of data repacking LDB is chosen in accordance with repacking buffer size
264+
if (with_repacking(brgemm_node->get_type())) {
265+
LDB = DIM_CAST(brgemm_utils::repacking::compute_repacked_n_dim(LDB, brgemm_node->get_input_element_type(1)));
266+
}
267+
} else {
268+
OV_CPU_JIT_EMITTER_ASSERT(false, "Got invalid node type in update_config");
262269
}
263270

264271
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);

0 commit comments

Comments
 (0)