Skip to content

Commit c92de3b

Browse files
CaoE authored and pytorchmergebot committed
Add BRGEMM API versioning to be compatible with different oneDNN versions (pytorch#138184)
oneDNN v3.6 updated the ukernel APIs of `brgemm` and `brgemm_pack_B`. Considering the upgrade of oneDNN, ukernel API versioning is needed to be compatible with different oneDNN versions. Pull Request resolved: pytorch#138184 Approved by: https://github.com/jgong5, https://github.com/peterbell10
1 parent 299dbcd commit c92de3b

File tree

3 files changed

+71
-46
lines changed

3 files changed

+71
-46
lines changed

aten/src/ATen/native/CPUBlas.cpp

+64-36
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,21 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *
4545
#endif // USE_FBGEMM
4646

4747
#if AT_MKLDNN_ENABLED()
48-
#include <oneapi/dnnl/dnnl_version.h>
49-
#endif // oneDNN
50-
51-
#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=5)
48+
#include <ideep.hpp>
49+
// Add uKernel API versioning to be compatible with different oneDNN versions
50+
// oneDNN 3.6.x updates the ukernel APIs of brgemm and brgemm_pack_B
51+
// brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C
52+
#if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5)
53+
#define ONEDNN_UKERNEL_1
54+
#elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6)
55+
#define ONEDNN_UKERNEL_2
56+
#endif
57+
#if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))))
58+
#define ONEDNN_UKERNEL_ENABLED
59+
#endif
60+
#endif // AT_MKLDNN_ENABLED()
5261

53-
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
62+
#if defined(ONEDNN_UKERNEL_ENABLED)
5463
#include <oneapi/dnnl/dnnl_ukernel.hpp>
5564
#include <oneapi/dnnl/dnnl.hpp>
5665
#endif // oneDNN BRGEMM
@@ -847,7 +856,7 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
847856
}
848857

849858
// oneDNN BRGEMM
850-
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
859+
#if defined(ONEDNN_UKERNEL_ENABLED)
851860
struct BrgemmKey {
852861
int64_t M;
853862
int64_t N;
@@ -859,8 +868,8 @@ struct BrgemmKey {
859868
ScalarType dt_a;
860869
ScalarType dt_b;
861870
ScalarType dt_c;
862-
float alpha;
863-
float beta;
871+
bool add_C;
872+
864873
BrgemmKey(
865874
int64_t M,
866875
int64_t N,
@@ -872,8 +881,7 @@ struct BrgemmKey {
872881
ScalarType dt_a,
873882
ScalarType dt_b,
874883
ScalarType dt_c,
875-
float alpha,
876-
float beta)
884+
bool add_C)
877885
: M(M),
878886
N(N),
879887
K(K),
@@ -884,14 +892,12 @@ struct BrgemmKey {
884892
dt_a(dt_a),
885893
dt_b(dt_b),
886894
dt_c(dt_c),
887-
alpha(alpha),
888-
beta(beta) {}
895+
add_C(add_C) {}
889896
bool operator==(const BrgemmKey& other) const {
890897
return M == other.M && N == other.N && K == other.K &&
891898
batch_size == other.batch_size && lda == other.lda &&
892899
ldb == other.ldb && ldc == other.ldc && dt_a == other.dt_a &&
893-
dt_b == other.dt_b && dt_c == other.dt_c && alpha == other.alpha &&
894-
beta == other.beta;
900+
dt_b == other.dt_b && dt_c == other.dt_c && add_C == other.add_C;
895901
}
896902
};
897903

@@ -945,13 +951,13 @@ struct UnsafeUkernelKeyHasher {
945951

946952
template<>
947953
std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator()(const BrgemmKey& key) const {
948-
// Use beta, M, N, and K to compute hash to reduce the overhead as
949-
// batch size, alpha, and data types are unlikely to change within the same kernel and
950-
// leading dimensions are likely to be related to M, K, N or use fixed values.
951-
std::size_t h = std::hash<float>()(key.beta + 1);
952-
h = std::hash<int64_t>()(key.M) ^ (h << 1);
954+
// Use M, N, K add_C, and ldc to compute hash to reduce the overhead as
955+
// batch size and data types are unlikely to change within the same kernel and
956+
// lda/ldb are likely to be related to M, K, N or use fixed values.
957+
std::size_t h = std::hash<int64_t>()(key.M);
953958
h = std::hash<int64_t>()(key.N) ^ (h << 1);
954959
h = std::hash<int64_t>()(key.K) ^ (h << 1);
960+
h = std::hash<bool>()(key.add_C) ^ (h << 1);
955961
h = std::hash<int64_t>()(key.ldc) ^ (h << 1);
956962
return h;
957963
}
@@ -1000,9 +1006,9 @@ struct GemmHelper {
10001006
ScalarType dt_a,
10011007
ScalarType dt_b,
10021008
ScalarType dt_c,
1003-
const float alpha,
1004-
const float beta) {
1009+
const bool add_C) {
10051010
// Create brgemm
1011+
#if defined(ONEDNN_UKERNEL_1)
10061012
brg = dnnl::ukernel::brgemm(
10071013
M,
10081014
N,
@@ -1014,8 +1020,23 @@ struct GemmHelper {
10141020
get_dnnl_dtype(dt_a),
10151021
get_dnnl_dtype(dt_b),
10161022
get_dnnl_dtype(dt_c),
1017-
alpha,
1018-
beta);
1023+
1,
1024+
add_C ? 1 : 0);
1025+
#elif defined(ONEDNN_UKERNEL_2)
1026+
brg = dnnl::ukernel::brgemm(
1027+
M,
1028+
N,
1029+
K,
1030+
bs,
1031+
ld_a,
1032+
ld_b,
1033+
ld_c,
1034+
get_dnnl_dtype(dt_a),
1035+
get_dnnl_dtype(dt_b),
1036+
get_dnnl_dtype(dt_c));
1037+
brg.set_add_C(add_C);
1038+
brg.finalize();
1039+
#endif
10191040
// Create a scratchpad buffer for the brgemm execution
10201041
scratchpad = std::vector<uint8_t>(brg.get_scratchpad_size());
10211042
// Prepare default vector of pairs of tensors A and B offsets for each batch.
@@ -1037,8 +1058,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
10371058
int64_t ld_a,
10381059
int64_t ld_b,
10391060
int64_t ld_c,
1040-
const float alpha,
1041-
const float beta,
1061+
const bool add_C,
10421062
const scalar_t_a* A,
10431063
const scalar_t_b* B,
10441064
scalar_t_c* C) {
@@ -1053,8 +1073,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
10531073
c10::CppTypeToScalarType<scalar_t_a>::value,
10541074
c10::CppTypeToScalarType<scalar_t_b>::value,
10551075
c10::CppTypeToScalarType<scalar_t_c>::value,
1056-
alpha,
1057-
beta);
1076+
add_C);
10581077
// Fetch/create GemmHelper object
10591078
auto&& value = fetch_or_create(key, [&]() {
10601079
auto&& v = std::make_shared<GemmHelper>(
@@ -1068,13 +1087,14 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
10681087
c10::CppTypeToScalarType<scalar_t_a>::value,
10691088
c10::CppTypeToScalarType<scalar_t_b>::value,
10701089
c10::CppTypeToScalarType<scalar_t_c>::value,
1071-
alpha,
1072-
beta);
1090+
add_C);
10731091
(*v).brg.generate();
10741092
return std::move(v);
10751093
});
10761094
if (get_current() != value) {
1095+
#if defined(ONEDNN_UKERNEL_1)
10771096
dnnl::ukernel::brgemm::release_hw_context();
1097+
#endif
10781098
((*value).brg).set_hw_context();
10791099
get_current() = value;
10801100
}
@@ -1099,7 +1119,11 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
10991119
}
11001120
};
11011121

1122+
#if defined(ONEDNN_UKERNEL_1)
11021123
using pack_t = dnnl::ukernel::brgemm_pack_B;
1124+
#elif defined(ONEDNN_UKERNEL_2)
1125+
using pack_t = dnnl::ukernel::transform;
1126+
#endif
11031127
struct Pack : public KernelCache <PackKey, pack_t> {
11041128
static inline void call(
11051129
int64_t K,
@@ -1113,7 +1137,11 @@ struct Pack : public KernelCache <PackKey, pack_t> {
11131137
auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out);
11141138
auto&& pack = fetch_or_create(key, [&]() {
11151139
auto&& p = std::make_shared<pack_t>(
1140+
#if defined(ONEDNN_UKERNEL_1)
11161141
K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
1142+
#elif defined(ONEDNN_UKERNEL_2)
1143+
K, N, dnnl::ukernel::pack_type::no_trans, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
1144+
#endif
11171145
if (need_pack(dt_in)) {
11181146
(*p).generate();
11191147
}
@@ -1146,15 +1174,14 @@ void brgemm(
11461174
int64_t ld_a,
11471175
int64_t ld_b,
11481176
int64_t ld_c,
1149-
const float alpha,
1150-
const float beta,
1177+
const bool add_C,
11511178
const at::Half* A,
11521179
const at::Half* B,
11531180
float* C) {
1154-
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1181+
#if defined(ONEDNN_UKERNEL_ENABLED)
11551182
if (Brgemm::device_check(ScalarType::Half)) {
11561183
Brgemm::call<at::Half, at::Half, float>(
1157-
M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C);
1184+
M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
11581185
return;
11591186
}
11601187
#endif
@@ -1163,8 +1190,9 @@ void brgemm(
11631190
}
11641191

11651192
void brgemm_release() {
1166-
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1193+
#if defined(ONEDNN_UKERNEL_ENABLED)
11671194
dnnl::ukernel::brgemm::release_hw_context();
1195+
Brgemm::get_current() = nullptr;
11681196
#endif
11691197
}
11701198

@@ -1177,15 +1205,15 @@ void pack(
11771205
ScalarType dt_out,
11781206
const void* in,
11791207
void* out) {
1180-
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1208+
#if defined(ONEDNN_UKERNEL_ENABLED)
11811209
Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out);
11821210
#else
11831211
TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled");
11841212
#endif
11851213
}
11861214

11871215
bool need_pack(ScalarType dt_in) {
1188-
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1216+
#if defined(ONEDNN_UKERNEL_ENABLED)
11891217
return Pack::need_pack(dt_in);
11901218
#else
11911219
return false;

aten/src/ATen/native/CPUBlas.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
189189

190190
// Batch-reduce GEMM
191191
// Operates by the following formula:
192-
// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size
192+
// C = SUM(A[i] x B[i]) + C if add_C is true, i = 0 to batch size
193193
// A Base pointer to a tensor A.
194194
// B Base pointer to a tensor B.
195195
// C Pointer to a tensor C (accumulation buffer).
@@ -200,8 +200,7 @@ TORCH_API void brgemm(
200200
int64_t ld_a,
201201
int64_t ld_b,
202202
int64_t ld_c,
203-
const float alpha,
204-
const float beta,
203+
const bool add_C,
205204
const at::Half* A,
206205
const at::Half* B,
207206
float* C);

aten/src/ATen/native/cpu/FlashAttentionKernel.cpp

+5-7
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,7 @@ void cpu_flash_attention(
603603
headSize_even ? qStrideM : eheadSize,
604604
packb_size,
605605
rkvBlockSize,
606-
1.f,
607-
0.f,
606+
false,
608607
!headSize_even
609608
? query_t_padding_ptr
610609
: q_data + i * qStrideB + j * qStrideH + m * qStrideM,
@@ -738,8 +737,7 @@ void cpu_flash_attention(
738737
ekvBlockSize,
739738
packb_size,
740739
rHeadSize,
741-
1.0,
742-
n == 0 ? 0.f : 1.f,
740+
n > 0,
743741
qk_reduced_data,
744742
value_reorder_ptr +
745743
i * num_head * kv_padding_size * rHeadSize +
@@ -791,10 +789,10 @@ void cpu_flash_attention(
791789
// Move to the next query
792790
data_index_step(i, batchSize, j, num_head, k, qSlice);
793791
}
792+
if (need_pack) {
793+
cpublas::brgemm_release();
794+
}
794795
});
795-
if (need_pack) {
796-
cpublas::brgemm_release();
797-
}
798796
}
799797

800798
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>

0 commit comments

Comments (0)