Skip to content

Commit 9991950

Browse files
author
Vladimir Paramuzov
authored
Port changes from uxlfoundation/oneDNN#2372
1 parent 940aaa7 commit 9991950

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp

+52-5
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,26 @@ sdpa_config_t xehpg_h32_s64 = {16, 16, 16, 8, 4, 4, 2, 8};
8585
sdpa_config_t xehpg_h32_s32 = {8, 8, 8, 8, 4, 4, 4, 4};
8686
sdpa_config_t xehpg_h32_2nd = {8, 32, 16, 8, 8, 1, 2, 4};
8787

88+
// Config returned by choose_config_xehpg() for head_size <= 32 when
// `quantized` is set, seq >= 128 and thin_q is false.
// NOTE(review): the meaning of the eight fields is defined by sdpa_config_t,
// which is not visible here.
sdpa_config_t xehpg_q_h32 = {32, 16, 16, 16, 2, 8, 2, 8};
89+
// Config returned by choose_config_xehpg() for head_size <= 32 when
// `quantized` is set, seq >= 128 and thin_q is true.
sdpa_config_t xehpg_q_h32_2nd = {32, 16, 8, 8, 8, 1, 4, 2};
90+
8891
sdpa_config_t xehpg_h64 = {32, 16, 16, 16, 4, 8, 4, 8};
8992
sdpa_config_t xehpg_h64_s128 = {16, 16, 16, 16, 4, 8, 4, 8};
9093
sdpa_config_t xehpg_h64_s64 = {32, 16, 16, 8, 8, 4, 4, 8};
9194
sdpa_config_t xehpg_h64_2nd = {8, 16, 16, 8, 8, 1, 4, 2};
9295

96+
// Config returned by choose_config_xehpg() for 32 < head_size <= 64 when
// `quantized` is set and thin_q is false (no dependence on seq).
sdpa_config_t xehpg_q_h64 = {32, 16, 16, 16, 4, 4, 4, 4};
97+
// Config returned by choose_config_xehpg() for 32 < head_size <= 64 when
// `quantized` is set and thin_q is true (no dependence on seq).
sdpa_config_t xehpg_q_h64_2nd = {16, 16, 8, 8, 16, 1, 8, 2};
98+
9399
sdpa_config_t xehpg_h128 = {16, 16, 32, 8, 8, 4, 4, 8};
94100
sdpa_config_t xehpg_h128_s32 = {16, 16, 16, 8, 16, 2, 8, 4};
95101
sdpa_config_t xehpg_h128_2nd = {8, 16, 16, 8, 16, 1, 8, 2};
96102
sdpa_config_t xehpg_h128_s256_2nd = {8, 16, 32, 8, 8, 1, 4, 2};
97103

104+
// Config returned by choose_config_xehpg() for 64 < head_size <= 128 when
// `quantized` is set, thin_q is false and seq > 32. For seq <= 32 the
// quantized path returns the non-quantized xehpg_h128_s32 instead.
sdpa_config_t xehpg_q_h128 = {32, 16, 16, 16, 8, 4, 8, 4};
105+
// Config returned by choose_config_xehpg() for 64 < head_size <= 128 when
// `quantized` is set, thin_q is true and seq > 64.
sdpa_config_t xehpg_q_h128_2nd = {32, 16, 16, 8, 16, 1, 8, 2};
106+
// Config returned by choose_config_xehpg() for 64 < head_size <= 128 when
// `quantized` is set, thin_q is true and seq <= 64.
sdpa_config_t xehpg_q_h128_s64_2nd = {16, 16, 16, 8, 16, 1, 8, 2};
107+
98108
sdpa_config_t xehpg_h256 = {16, 16, 32, 8, 16, 2, 8, 4};
99109
sdpa_config_t xehpg_h256_s128 = {8, 16, 32, 16, 8, 4, 8, 4};
100110
sdpa_config_t xehpg_h256_s32 = {8, 16, 32, 8, 16, 2, 8, 4};
@@ -112,28 +122,52 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2};
112122
sdpa_config_t xehpc_h64_2nd = {32, 32, 32, 16, 4, 1, 2, 2};
113123
sdpa_config_t xehpc_h64_s64_2nd = {16, 16, 16, 16, 4, 1, 4, 1};
114124

125+
// Config returned by choose_config_xehpc() when `quantized` is set and
// seq >= 256; that check runs before the seq <= 32 / seq <= 64 checks that
// pick the non-quantized xehpc_h64* configs.
// NOTE(review): presumably this sits in the head_size <= 64, non-thin_q
// branch -- the enclosing conditions are not visible here; confirm.
sdpa_config_t xehpc_q_h64 = {16, 64, 32, 16, 8, 4, 2, 16};
126+
115127
sdpa_config_t xehpc_h128 = {16, 64, 32, 16, 16, 2, 4, 8};
116128
sdpa_config_t xehpc_h128_s64 = {16, 32, 32, 32, 4, 2, 4, 2};
117129
sdpa_config_t xehpc_h128_s32 = {16, 16, 16, 16, 8, 2, 8, 2};
118130
sdpa_config_t xehpc_h128_2nd = {32, 32, 32, 16, 8, 1, 4, 2};
119131

132+
// Config returned by choose_config_xehpc() in the head_size <= 128 branch
// when `quantized` is set, thin_q is false and seq > 64.
sdpa_config_t xehpc_q_h128 = {16, 64, 16, 32, 16, 2, 8, 4};
133+
// Config returned by choose_config_xehpc() in the head_size <= 128 branch
// when `quantized` is set, thin_q is false and 32 < seq <= 64.
sdpa_config_t xehpc_q_h128_s64 = {16, 16, 32, 16, 4, 4, 4, 4};
134+
// Config returned by choose_config_xehpc() in the head_size <= 128 branch
// when `quantized` is set, thin_q is false and seq <= 32.
sdpa_config_t xehpc_q_h128_s32 = {16, 16, 32, 16, 4, 2, 4, 2};
135+
// Config returned by choose_config_xehpc() in the head_size <= 128 branch
// when `quantized` is set, thin_q is true and seq > 32.
sdpa_config_t xehpc_q_h128_2nd = {32, 32, 16, 32, 4, 1, 4, 1};
136+
// Config returned by choose_config_xehpc() in the head_size <= 128 branch
// when `quantized` is set, thin_q is true and seq <= 32.
sdpa_config_t xehpc_q_h128_s32_2nd = {16, 32, 16, 16, 8, 1, 4, 2};
137+
120138
sdpa_config_t xehpc_h256 = {16, 32, 32, 32, 8, 4, 8, 4};
121139
sdpa_config_t xehpc_h256_s64 = {16, 32, 32, 32, 8, 1, 8, 1};
122140
sdpa_config_t xehpc_h256_2nd = {16, 16, 16, 16, 16, 1, 16, 1};
123141

124-
sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q) {
142+
sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q, bool quantized) {
125143
if (head_size <= 32) {
144+
if (quantized && seq >= 128) {
145+
if (thin_q) return &xehpg_q_h32_2nd;
146+
return &xehpg_q_h32;
147+
}
126148
if (thin_q) return &xehpg_h32_2nd;
127149
if (seq <= 32) return &xehpg_h32_s32;
128150
if (seq <= 64) return &xehpg_h32_s64;
129151
if (seq <= 256) return &xehpg_h32_s256;
130152
return &xehpg_h32;
131153
} else if (head_size <= 64) {
154+
if (quantized) {
155+
if (thin_q) return &xehpg_q_h64_2nd;
156+
return &xehpg_q_h64;
157+
}
132158
if (thin_q) return &xehpg_h64_2nd;
133159
if (seq <= 64) return &xehpg_h64_s64;
134160
if (seq <= 128) return &xehpg_h64_s128;
135161
return &xehpg_h64;
136162
} else if (head_size <= 128) {
163+
if (quantized) {
164+
if (thin_q) {
165+
if (seq <= 64) return &xehpg_q_h128_s64_2nd;
166+
return &xehpg_q_h128_2nd;
167+
}
168+
if (seq <= 32) return &xehpg_h128_s32;
169+
return &xehpg_q_h128;
170+
}
137171
if (thin_q) {
138172
if (seq <= 256) return &xehpg_h128_s256_2nd;
139173
return &xehpg_h128_2nd;
@@ -153,7 +187,7 @@ sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q) {
153187
return nullptr;
154188
}
155189

156-
sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) {
190+
sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q, bool quantized) {
157191
if (head_size <= 32) {
158192
if (thin_q) return &xehpc_h32_2nd;
159193
if (seq <= 32) return &xehpc_h32_s32;
@@ -163,10 +197,20 @@ sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) {
163197
if (seq <= 64) return &xehpc_h64_s64_2nd;
164198
return &xehpc_h64_2nd;
165199
}
200+
if (quantized && seq >= 256) return &xehpc_q_h64;
166201
if (seq <= 32) return &xehpc_h64_s32;
167202
if (seq <= 64) return &xehpc_h64_s64;
168203
return &xehpc_h64;
169204
} else if (head_size <= 128) {
205+
if (quantized) {
206+
if (thin_q) {
207+
if (seq <= 32) return &xehpc_q_h128_s32_2nd;
208+
return &xehpc_q_h128_2nd;
209+
}
210+
if (seq <= 32) return &xehpc_q_h128_s32;
211+
if (seq <= 64) return &xehpc_q_h128_s64;
212+
return &xehpc_q_h128;
213+
}
170214
if (thin_q) return &xehpc_h128_2nd;
171215
if (seq <= 32) return &xehpc_h128_s32;
172216
if (seq <= 64) return &xehpc_h128_s64;
@@ -207,15 +251,18 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
207251
sdpa_config_t *config = nullptr;
208252
bool thin_q = (!n_queries.is_dynamic && (n_queries.v <= 16)) || !is_prefill;
209253

254+
bool is_quantized = (K.GetDType() == Datatype::UINT8 || K.GetDType() == Datatype::INT8) ||
255+
(V.GetDType() == Datatype::UINT8 || V.GetDType() == Datatype::INT8);
256+
210257
switch (params.engineInfo.arch) {
211258
case gpu_arch::xe_hpg: {
212-
config = choose_config_xehpg(static_cast<int32_t>(head_size), static_cast<int32_t>(n_keys.v), thin_q);
259+
config = choose_config_xehpg(static_cast<int32_t>(head_size), static_cast<int32_t>(n_keys.v), thin_q, is_quantized);
213260
break;
214261
}
215262
case gpu_arch::xe_hpc:
216263
case gpu_arch::xe2:
217264
case gpu_arch::xe3: {
218-
config = choose_config_xehpc(static_cast<int32_t>(head_size), static_cast<int32_t>(n_keys.v), thin_q);
265+
config = choose_config_xehpc(static_cast<int32_t>(head_size), static_cast<int32_t>(n_keys.v), thin_q, is_quantized);
219266
break;
220267
}
221268
default: break;
@@ -330,7 +377,7 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
330377
}
331378

332379
if (params.conf.is_kv_compressed) {
333-
problem_vs.aqGroupM = (vs_common_scales || vs_common_zp) ? 1 : params.conf.head_size;
380+
problem_vs.aqGroupM = (vs_common_scales || vs_common_zp) ? 1 : micro::rnd_up_pow2(params.conf.head_size);
334381
problem_vs.aqGroupK = 1;
335382
}
336383

0 commit comments

Comments
 (0)