@@ -45,12 +45,21 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *
45
45
#endif // USE_FBGEMM
46
46
47
47
#if AT_MKLDNN_ENABLED()
48
- #include < oneapi/dnnl/dnnl_version.h>
49
- #endif // oneDNN
50
-
51
- #define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=5 )
48
+ #include < ideep.hpp>
49
+ // Add uKernel API versioning to be compatible with different oneDNN versions
50
+ // oneDNN 3.6.x updates the ukernel APIs of brgemm and brgemm_pack_B
51
+ // brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C
52
+ #if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5)
53
+ #define ONEDNN_UKERNEL_1
54
+ #elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6)
55
+ #define ONEDNN_UKERNEL_2
56
+ #endif
57
+ #if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))))
58
+ #define ONEDNN_UKERNEL_ENABLED
59
+ #endif
60
+ #endif // AT_MKLDNN_ENABLED()
52
61
53
- #if ONEDNN_UKERNEL_ENABLED && ( defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) )
62
+ #if defined(ONEDNN_UKERNEL_ENABLED )
54
63
#include < oneapi/dnnl/dnnl_ukernel.hpp>
55
64
#include < oneapi/dnnl/dnnl.hpp>
56
65
#endif // oneDNN BRGEMM
@@ -847,7 +856,7 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
847
856
}
848
857
849
858
// oneDNN BRGEMM
850
- #if ONEDNN_UKERNEL_ENABLED && ( defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) )
859
+ #if defined(ONEDNN_UKERNEL_ENABLED )
851
860
struct BrgemmKey {
852
861
int64_t M;
853
862
int64_t N;
@@ -859,8 +868,8 @@ struct BrgemmKey {
859
868
ScalarType dt_a;
860
869
ScalarType dt_b;
861
870
ScalarType dt_c;
862
- float alpha ;
863
- float beta;
871
+ bool add_C ;
872
+
864
873
BrgemmKey (
865
874
int64_t M,
866
875
int64_t N,
@@ -872,8 +881,7 @@ struct BrgemmKey {
872
881
ScalarType dt_a,
873
882
ScalarType dt_b,
874
883
ScalarType dt_c,
875
- float alpha,
876
- float beta)
884
+ bool add_C)
877
885
: M(M),
878
886
N (N),
879
887
K(K),
@@ -884,14 +892,12 @@ struct BrgemmKey {
884
892
dt_a(dt_a),
885
893
dt_b(dt_b),
886
894
dt_c(dt_c),
887
- alpha(alpha),
888
- beta(beta) {}
895
+ add_C(add_C) {}
889
896
bool operator ==(const BrgemmKey& other) const {
890
897
return M == other.M && N == other.N && K == other.K &&
891
898
batch_size == other.batch_size && lda == other.lda &&
892
899
ldb == other.ldb && ldc == other.ldc && dt_a == other.dt_a &&
893
- dt_b == other.dt_b && dt_c == other.dt_c && alpha == other.alpha &&
894
- beta == other.beta ;
900
+ dt_b == other.dt_b && dt_c == other.dt_c && add_C == other.add_C ;
895
901
}
896
902
};
897
903
@@ -945,13 +951,13 @@ struct UnsafeUkernelKeyHasher {
945
951
946
952
template <>
947
953
std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator ()(const BrgemmKey& key) const {
948
- // Use beta, M, N, and K to compute hash to reduce the overhead as
949
- // batch size, alpha, and data types are unlikely to change within the same kernel and
950
- // leading dimensions are likely to be related to M, K, N or use fixed values.
951
- std::size_t h = std::hash<float >()(key.beta + 1 );
952
- h = std::hash<int64_t >()(key.M ) ^ (h << 1 );
954
+ // Use M, N, K add_C, and ldc to compute hash to reduce the overhead as
955
+ // batch size and data types are unlikely to change within the same kernel and
956
+ // lda/ldb are likely to be related to M, K, N or use fixed values.
957
+ std::size_t h = std::hash<int64_t >()(key.M );
953
958
h = std::hash<int64_t >()(key.N ) ^ (h << 1 );
954
959
h = std::hash<int64_t >()(key.K ) ^ (h << 1 );
960
+ h = std::hash<bool >()(key.add_C ) ^ (h << 1 );
955
961
h = std::hash<int64_t >()(key.ldc ) ^ (h << 1 );
956
962
return h;
957
963
}
@@ -1000,9 +1006,9 @@ struct GemmHelper {
1000
1006
ScalarType dt_a,
1001
1007
ScalarType dt_b,
1002
1008
ScalarType dt_c,
1003
- const float alpha,
1004
- const float beta) {
1009
+ const bool add_C) {
1005
1010
// Create brgemm
1011
+ #if defined(ONEDNN_UKERNEL_1)
1006
1012
brg = dnnl::ukernel::brgemm (
1007
1013
M,
1008
1014
N,
@@ -1014,8 +1020,23 @@ struct GemmHelper {
1014
1020
get_dnnl_dtype (dt_a),
1015
1021
get_dnnl_dtype (dt_b),
1016
1022
get_dnnl_dtype (dt_c),
1017
- alpha,
1018
- beta);
1023
+ 1 ,
1024
+ add_C ? 1 : 0 );
1025
+ #elif defined(ONEDNN_UKERNEL_2)
1026
+ brg = dnnl::ukernel::brgemm (
1027
+ M,
1028
+ N,
1029
+ K,
1030
+ bs,
1031
+ ld_a,
1032
+ ld_b,
1033
+ ld_c,
1034
+ get_dnnl_dtype (dt_a),
1035
+ get_dnnl_dtype (dt_b),
1036
+ get_dnnl_dtype (dt_c));
1037
+ brg.set_add_C (add_C);
1038
+ brg.finalize ();
1039
+ #endif
1019
1040
// Create a scratchpad buffer for the brgemm execution
1020
1041
scratchpad = std::vector<uint8_t >(brg.get_scratchpad_size ());
1021
1042
// Prepare default vector of pairs of tensors A and B offsets for each batch.
@@ -1037,8 +1058,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
1037
1058
int64_t ld_a,
1038
1059
int64_t ld_b,
1039
1060
int64_t ld_c,
1040
- const float alpha,
1041
- const float beta,
1061
+ const bool add_C,
1042
1062
const scalar_t_a* A,
1043
1063
const scalar_t_b* B,
1044
1064
scalar_t_c* C) {
@@ -1053,8 +1073,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
1053
1073
c10::CppTypeToScalarType<scalar_t_a>::value,
1054
1074
c10::CppTypeToScalarType<scalar_t_b>::value,
1055
1075
c10::CppTypeToScalarType<scalar_t_c>::value,
1056
- alpha,
1057
- beta);
1076
+ add_C);
1058
1077
// Fetch/create GemmHelper object
1059
1078
auto && value = fetch_or_create (key, [&]() {
1060
1079
auto && v = std::make_shared<GemmHelper>(
@@ -1068,13 +1087,14 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
1068
1087
c10::CppTypeToScalarType<scalar_t_a>::value,
1069
1088
c10::CppTypeToScalarType<scalar_t_b>::value,
1070
1089
c10::CppTypeToScalarType<scalar_t_c>::value,
1071
- alpha,
1072
- beta);
1090
+ add_C);
1073
1091
(*v).brg .generate ();
1074
1092
return std::move (v);
1075
1093
});
1076
1094
if (get_current () != value) {
1095
+ #if defined(ONEDNN_UKERNEL_1)
1077
1096
dnnl::ukernel::brgemm::release_hw_context ();
1097
+ #endif
1078
1098
((*value).brg ).set_hw_context ();
1079
1099
get_current () = value;
1080
1100
}
@@ -1099,7 +1119,11 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
1099
1119
}
1100
1120
};
1101
1121
1122
+ #if defined(ONEDNN_UKERNEL_1)
1102
1123
using pack_t = dnnl::ukernel::brgemm_pack_B;
1124
+ #elif defined(ONEDNN_UKERNEL_2)
1125
+ using pack_t = dnnl::ukernel::transform;
1126
+ #endif
1103
1127
struct Pack : public KernelCache <PackKey, pack_t > {
1104
1128
static inline void call (
1105
1129
int64_t K,
@@ -1113,7 +1137,11 @@ struct Pack : public KernelCache <PackKey, pack_t> {
1113
1137
auto && key = PackKey (K, N, ld_in, ld_out, dt_in, dt_out);
1114
1138
auto && pack = fetch_or_create (key, [&]() {
1115
1139
auto && p = std::make_shared<pack_t >(
1140
+ #if defined(ONEDNN_UKERNEL_1)
1116
1141
K, N, ld_in, ld_out, get_dnnl_dtype (dt_in), get_dnnl_dtype (dt_out));
1142
+ #elif defined(ONEDNN_UKERNEL_2)
1143
+ K, N, dnnl::ukernel::pack_type::no_trans, ld_in, ld_out, get_dnnl_dtype (dt_in), get_dnnl_dtype (dt_out));
1144
+ #endif
1117
1145
if (need_pack (dt_in)) {
1118
1146
(*p).generate ();
1119
1147
}
@@ -1146,15 +1174,14 @@ void brgemm(
1146
1174
int64_t ld_a,
1147
1175
int64_t ld_b,
1148
1176
int64_t ld_c,
1149
- const float alpha,
1150
- const float beta,
1177
+ const bool add_C,
1151
1178
const at::Half* A,
1152
1179
const at::Half* B,
1153
1180
float * C) {
1154
- #if ONEDNN_UKERNEL_ENABLED && ( defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) )
1181
+ #if defined(ONEDNN_UKERNEL_ENABLED )
1155
1182
if (Brgemm::device_check (ScalarType::Half)) {
1156
1183
Brgemm::call<at::Half, at::Half, float >(
1157
- M, N, K, ld_a, ld_b, ld_c, alpha, beta , A, B, C);
1184
+ M, N, K, ld_a, ld_b, ld_c, add_C , A, B, C);
1158
1185
return ;
1159
1186
}
1160
1187
#endif
@@ -1163,8 +1190,9 @@ void brgemm(
1163
1190
}
1164
1191
1165
1192
void brgemm_release () {
1166
- #if ONEDNN_UKERNEL_ENABLED && ( defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) )
1193
+ #if defined(ONEDNN_UKERNEL_ENABLED )
1167
1194
dnnl::ukernel::brgemm::release_hw_context ();
1195
+ Brgemm::get_current () = nullptr ;
1168
1196
#endif
1169
1197
}
1170
1198
@@ -1177,15 +1205,15 @@ void pack(
1177
1205
ScalarType dt_out,
1178
1206
const void * in,
1179
1207
void * out) {
1180
- #if ONEDNN_UKERNEL_ENABLED && ( defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) )
1208
+ #if defined(ONEDNN_UKERNEL_ENABLED )
1181
1209
Pack::call (K, N, ld_in, ld_out, dt_in, dt_out, in, out);
1182
1210
#else
1183
1211
TORCH_CHECK (false , " pack is only supported on X64 with oneDNN ukernel enabled" );
1184
1212
#endif
1185
1213
}
1186
1214
1187
1215
bool need_pack (ScalarType dt_in) {
1188
- #if ONEDNN_UKERNEL_ENABLED && ( defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)) )
1216
+ #if defined(ONEDNN_UKERNEL_ENABLED )
1189
1217
return Pack::need_pack (dt_in);
1190
1218
#else
1191
1219
return false ;
0 commit comments