3
3
//
4
4
5
5
#include " jit_brgemm_emitter.hpp"
6
+
6
7
#include " emitters/snippets/x64/jit_snippets_emitters.hpp"
7
8
#include " transformations/tpp/x64/op/brgemm.hpp"
8
9
@@ -28,18 +29,15 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
28
29
const auto & input_1_desc = expr->get_input_port_descriptor (1 );
29
30
const auto & output_desc = expr->get_output_port_descriptor (0 );
30
31
31
- std::vector<size_t > leading_dimensions {brgemm_node->get_input_stride (0 ),
32
- brgemm_node->get_input_stride (1 ),
33
- brgemm_node->get_output_stride (0 )};
32
+ std::vector<size_t > leading_dimensions{brgemm_node->get_input_stride (0 ),
33
+ brgemm_node->get_input_stride (1 ),
34
+ brgemm_node->get_output_stride (0 )};
34
35
35
36
auto in_0_prec = ov_to_xsmm_dtype (brgemm_node->get_input_element_type (0 ));
36
37
auto in_1_prec = ov_to_xsmm_dtype (brgemm_node->get_input_element_type (1 ));
37
- exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ?
38
- LIBXSMM_DATATYPE_I32 :
39
- LIBXSMM_DATATYPE_F32;
40
- auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ?
41
- LIBXSMM_DATATYPE_I32 :
42
- LIBXSMM_DATATYPE_F32;
38
+ exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ? LIBXSMM_DATATYPE_I32
39
+ : LIBXSMM_DATATYPE_F32;
40
+ auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ? LIBXSMM_DATATYPE_I32 : LIBXSMM_DATATYPE_F32;
43
41
44
42
const auto beta = brgemm_node->get_beta ();
45
43
OV_CPU_JIT_EMITTER_ASSERT (beta == 0 || beta == 1 , " Detected unsupported beta value: " + std::to_string (beta));
@@ -54,18 +52,14 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
54
52
const auto N = static_cast <libxsmm_blasint>(*subtensor_in1.rbegin ());
55
53
56
54
const bool is_f32_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_F32;
57
- const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16;
55
+ const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16;
58
56
const bool is_i8_gemm = in_0_prec == LIBXSMM_DATATYPE_U8 || in_0_prec == LIBXSMM_DATATYPE_I8;
59
- OV_CPU_JIT_EMITTER_ASSERT (is_f32_gemm ||
60
- (is_bf16_gemm && K % 2 == 0 ) ||
61
- (is_i8_gemm && K % 4 == 0 ),
57
+ OV_CPU_JIT_EMITTER_ASSERT (is_f32_gemm || (is_bf16_gemm && K % 2 == 0 ) || (is_i8_gemm && K % 4 == 0 ),
62
58
" Unsupported parameter combination for kernel configuration" );
63
59
64
- m_compile_flags = is_f32_gemm ?
65
- LIBXSMM_GEMM_FLAGS (' N' , ' N' ) :
66
- LIBXSMM_GEMM_VNNI_FLAGS (' N' , ' N' , ' V' , ' N' ) |
67
- LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG |
68
- LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG;
60
+ m_compile_flags = is_f32_gemm ? LIBXSMM_GEMM_FLAGS (' N' , ' N' )
61
+ : LIBXSMM_GEMM_VNNI_FLAGS (' N' , ' N' , ' V' , ' N' ) |
62
+ LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG;
69
63
70
64
if (beta == 0 )
71
65
m_compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0;
@@ -79,9 +73,15 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
79
73
m_compile_flags |= LIBXSMM_GEMM_FLAG_B_UNSIGNED;
80
74
}
81
75
82
- m_shape = libxsmm_create_gemm_shape (N, M, K,
83
- io_strides[1 ], io_strides[0 ], io_strides[2 ],
84
- in_1_prec, in_0_prec, out_0_prec,
76
+ m_shape = libxsmm_create_gemm_shape (N,
77
+ M,
78
+ K,
79
+ io_strides[1 ],
80
+ io_strides[0 ],
81
+ io_strides[2 ],
82
+ in_1_prec,
83
+ in_0_prec,
84
+ out_0_prec,
85
85
exec_dtype);
86
86
m_prefetching_flags = LIBXSMM_GEMM_PREFETCH_NONE;
87
87
}
@@ -91,7 +91,7 @@ std::set<std::vector<element::Type>> BrgemmTppEmitter::get_supported_precisions(
91
91
return {{element::f32, element::f32}};
92
92
}
93
93
94
- void BrgemmTppEmitter::validate_arguments (const std::vector<size_t > & in, const std::vector<size_t > & out) const {
94
+ void BrgemmTppEmitter::validate_arguments (const std::vector<size_t >& in, const std::vector<size_t >& out) const {
95
95
OV_CPU_JIT_EMITTER_ASSERT (in.size () == 2 , " Expects 2 input regs, got" + std::to_string (in.size ()));
96
96
OV_CPU_JIT_EMITTER_ASSERT (out.size () == 1 , " Expects 1 output reg, got" + std::to_string (out.size ()));
97
97
}
@@ -100,7 +100,7 @@ const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {
100
100
return COMPILE_TPP_KERNEL (libxsmm_dispatch_gemm (m_shape, m_compile_flags, m_prefetching_flags));
101
101
}
102
102
103
- void BrgemmTppEmitter::execute_brgemm_kernel (libxsmm_gemmfunction brg_kernel, void * in0, void * in1, void * out0) {
103
+ void BrgemmTppEmitter::execute_brgemm_kernel (libxsmm_gemmfunction brg_kernel, void * in0, void * in1, void * out0) {
104
104
libxsmm_gemm_param gemm_p;
105
105
gemm_p.a .primary = in1;
106
106
gemm_p.b .primary = in0;
0 commit comments