@@ -85,16 +85,26 @@ sdpa_config_t xehpg_h32_s64 = {16, 16, 16, 8, 4, 4, 2, 8};
85
85
sdpa_config_t xehpg_h32_s32 = {8 , 8 , 8 , 8 , 4 , 4 , 4 , 4 };
86
86
sdpa_config_t xehpg_h32_2nd = {8 , 32 , 16 , 8 , 8 , 1 , 2 , 4 };
87
87
88
+ sdpa_config_t xehpg_q_h32 = {32 , 16 , 16 , 16 , 2 , 8 , 2 , 8 }; // Xe-HPG, quantized: chosen by choose_config_xehpg when head_size <= 32, seq >= 128 and thin_q is false. Field meanings come from sdpa_config_t (declared outside this view).
89
+ sdpa_config_t xehpg_q_h32_2nd = {32 , 16 , 8 , 8 , 8 , 1 , 4 , 2 }; // Xe-HPG, quantized: chosen when head_size <= 32, seq >= 128 and thin_q is true. Quantized inputs with seq < 128 fall through to the non-quantized h32 configs.
90
+
88
91
sdpa_config_t xehpg_h64 = {32 , 16 , 16 , 16 , 4 , 8 , 4 , 8 };
89
92
sdpa_config_t xehpg_h64_s128 = {16 , 16 , 16 , 16 , 4 , 8 , 4 , 8 };
90
93
sdpa_config_t xehpg_h64_s64 = {32 , 16 , 16 , 8 , 8 , 4 , 4 , 8 };
91
94
sdpa_config_t xehpg_h64_2nd = {8 , 16 , 16 , 8 , 8 , 1 , 4 , 2 };
92
95
96
+ sdpa_config_t xehpg_q_h64 = {32 , 16 , 16 , 16 , 4 , 4 , 4 , 4 }; // Xe-HPG, quantized: chosen by choose_config_xehpg when 32 < head_size <= 64 and thin_q is false (no seq split).
97
+ sdpa_config_t xehpg_q_h64_2nd = {16 , 16 , 8 , 8 , 16 , 1 , 8 , 2 }; // Xe-HPG, quantized: chosen when 32 < head_size <= 64 and thin_q is true (no seq split).
98
+
93
99
sdpa_config_t xehpg_h128 = {16 , 16 , 32 , 8 , 8 , 4 , 4 , 8 };
94
100
sdpa_config_t xehpg_h128_s32 = {16 , 16 , 16 , 8 , 16 , 2 , 8 , 4 };
95
101
sdpa_config_t xehpg_h128_2nd = {8 , 16 , 16 , 8 , 16 , 1 , 8 , 2 };
96
102
sdpa_config_t xehpg_h128_s256_2nd = {8 , 16 , 32 , 8 , 8 , 1 , 4 , 2 };
97
103
104
+ sdpa_config_t xehpg_q_h128 = {32 , 16 , 16 , 16 , 8 , 4 , 8 , 4 }; // Xe-HPG, quantized: chosen when 64 < head_size <= 128, thin_q is false and seq > 32. NOTE(review): for seq <= 32 the quantized path returns the non-quantized xehpg_h128_s32 -- confirm this is intentional.
105
+ sdpa_config_t xehpg_q_h128_2nd = {32 , 16 , 16 , 8 , 16 , 1 , 8 , 2 }; // Xe-HPG, quantized: chosen when 64 < head_size <= 128, thin_q is true and seq > 64.
106
+ sdpa_config_t xehpg_q_h128_s64_2nd = {16 , 16 , 16 , 8 , 16 , 1 , 8 , 2 }; // Xe-HPG, quantized: chosen when 64 < head_size <= 128, thin_q is true and seq <= 64.
107
+
98
108
sdpa_config_t xehpg_h256 = {16 , 16 , 32 , 8 , 16 , 2 , 8 , 4 };
99
109
sdpa_config_t xehpg_h256_s128 = {8 , 16 , 32 , 16 , 8 , 4 , 8 , 4 };
100
110
sdpa_config_t xehpg_h256_s32 = {8 , 16 , 32 , 8 , 16 , 2 , 8 , 4 };
@@ -112,28 +122,52 @@ sdpa_config_t xehpc_h64_s32 = {16, 16, 16, 16, 4, 2, 4, 2};
112
122
sdpa_config_t xehpc_h64_2nd = {32 , 32 , 32 , 16 , 4 , 1 , 2 , 2 };
113
123
sdpa_config_t xehpc_h64_s64_2nd = {16 , 16 , 16 , 16 , 4 , 1 , 4 , 1 };
114
124
125
+ sdpa_config_t xehpc_q_h64 = {16 , 64 , 32 , 16 , 8 , 4 , 2 , 16 }; // Quantized config returned by choose_config_xehpc when quantized && seq >= 256, checked before the seq-based h64 configs. Presumably the non-thin_q, head_size <= 64 branch -- the enclosing condition is outside this view, TODO confirm.
126
+
115
127
sdpa_config_t xehpc_h128 = {16 , 64 , 32 , 16 , 16 , 2 , 4 , 8 };
116
128
sdpa_config_t xehpc_h128_s64 = {16 , 32 , 32 , 32 , 4 , 2 , 4 , 2 };
117
129
sdpa_config_t xehpc_h128_s32 = {16 , 16 , 16 , 16 , 8 , 2 , 8 , 2 };
118
130
sdpa_config_t xehpc_h128_2nd = {32 , 32 , 32 , 16 , 8 , 1 , 4 , 2 };
119
131
132
+ sdpa_config_t xehpc_q_h128 = {16 , 64 , 16 , 32 , 16 , 2 , 8 , 4 }; // Quantized, head_size <= 128 branch of choose_config_xehpc: thin_q is false and seq > 64.
133
+ sdpa_config_t xehpc_q_h128_s64 = {16 , 16 , 32 , 16 , 4 , 4 , 4 , 4 }; // Quantized, head_size <= 128 branch: thin_q is false and 32 < seq <= 64.
134
+ sdpa_config_t xehpc_q_h128_s32 = {16 , 16 , 32 , 16 , 4 , 2 , 4 , 2 }; // Quantized, head_size <= 128 branch: thin_q is false and seq <= 32.
135
+ sdpa_config_t xehpc_q_h128_2nd = {32 , 32 , 16 , 32 , 4 , 1 , 4 , 1 }; // Quantized, head_size <= 128 branch: thin_q is true and seq > 32.
136
+ sdpa_config_t xehpc_q_h128_s32_2nd = {16 , 32 , 16 , 16 , 8 , 1 , 4 , 2 }; // Quantized, head_size <= 128 branch: thin_q is true and seq <= 32.
137
+
120
138
sdpa_config_t xehpc_h256 = {16 , 32 , 32 , 32 , 8 , 4 , 8 , 4 };
121
139
sdpa_config_t xehpc_h256_s64 = {16 , 32 , 32 , 32 , 8 , 1 , 8 , 1 };
122
140
sdpa_config_t xehpc_h256_2nd = {16 , 16 , 16 , 16 , 16 , 1 , 16 , 1 };
123
141
124
- sdpa_config_t *choose_config_xehpg (int head_size, int seq, bool thin_q) {
142
+ sdpa_config_t *choose_config_xehpg (int head_size, int seq, bool thin_q, bool quantized ) {
125
143
if (head_size <= 32 ) {
144
+ if (quantized && seq >= 128 ) {
145
+ if (thin_q) return &xehpg_q_h32_2nd;
146
+ return &xehpg_q_h32;
147
+ }
126
148
if (thin_q) return &xehpg_h32_2nd;
127
149
if (seq <= 32 ) return &xehpg_h32_s32;
128
150
if (seq <= 64 ) return &xehpg_h32_s64;
129
151
if (seq <= 256 ) return &xehpg_h32_s256;
130
152
return &xehpg_h32;
131
153
} else if (head_size <= 64 ) {
154
+ if (quantized) {
155
+ if (thin_q) return &xehpg_q_h64_2nd;
156
+ return &xehpg_q_h64;
157
+ }
132
158
if (thin_q) return &xehpg_h64_2nd;
133
159
if (seq <= 64 ) return &xehpg_h64_s64;
134
160
if (seq <= 128 ) return &xehpg_h64_s128;
135
161
return &xehpg_h64;
136
162
} else if (head_size <= 128 ) {
163
+ if (quantized) {
164
+ if (thin_q) {
165
+ if (seq <= 64 ) return &xehpg_q_h128_s64_2nd;
166
+ return &xehpg_q_h128_2nd;
167
+ }
168
+ if (seq <= 32 ) return &xehpg_h128_s32;
169
+ return &xehpg_q_h128;
170
+ }
137
171
if (thin_q) {
138
172
if (seq <= 256 ) return &xehpg_h128_s256_2nd;
139
173
return &xehpg_h128_2nd;
@@ -153,7 +187,7 @@ sdpa_config_t *choose_config_xehpg(int head_size, int seq, bool thin_q) {
153
187
return nullptr ;
154
188
}
155
189
156
- sdpa_config_t *choose_config_xehpc (int head_size, int seq, bool thin_q) {
190
+ sdpa_config_t *choose_config_xehpc (int head_size, int seq, bool thin_q, bool quantized ) {
157
191
if (head_size <= 32 ) {
158
192
if (thin_q) return &xehpc_h32_2nd;
159
193
if (seq <= 32 ) return &xehpc_h32_s32;
@@ -163,10 +197,20 @@ sdpa_config_t *choose_config_xehpc(int head_size, int seq, bool thin_q) {
163
197
if (seq <= 64 ) return &xehpc_h64_s64_2nd;
164
198
return &xehpc_h64_2nd;
165
199
}
200
+ if (quantized && seq >= 256 ) return &xehpc_q_h64;
166
201
if (seq <= 32 ) return &xehpc_h64_s32;
167
202
if (seq <= 64 ) return &xehpc_h64_s64;
168
203
return &xehpc_h64;
169
204
} else if (head_size <= 128 ) {
205
+ if (quantized) {
206
+ if (thin_q) {
207
+ if (seq <= 32 ) return &xehpc_q_h128_s32_2nd;
208
+ return &xehpc_q_h128_2nd;
209
+ }
210
+ if (seq <= 32 ) return &xehpc_q_h128_s32;
211
+ if (seq <= 64 ) return &xehpc_q_h128_s64;
212
+ return &xehpc_q_h128;
213
+ }
170
214
if (thin_q) return &xehpc_h128_2nd;
171
215
if (seq <= 32 ) return &xehpc_h128_s32;
172
216
if (seq <= 64 ) return &xehpc_h128_s64;
@@ -207,15 +251,18 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
207
251
sdpa_config_t *config = nullptr ;
208
252
bool thin_q = (!n_queries.is_dynamic && (n_queries.v <= 16 )) || !is_prefill;
209
253
254
+ bool is_quantized = (K.GetDType () == Datatype::UINT8 || K.GetDType () == Datatype::INT8) ||
255
+ (V.GetDType () == Datatype::UINT8 || V.GetDType () == Datatype::INT8);
256
+
210
257
switch (params.engineInfo .arch ) {
211
258
case gpu_arch::xe_hpg: {
212
- config = choose_config_xehpg (static_cast <int32_t >(head_size), static_cast <int32_t >(n_keys.v ), thin_q);
259
+ config = choose_config_xehpg (static_cast <int32_t >(head_size), static_cast <int32_t >(n_keys.v ), thin_q, is_quantized );
213
260
break ;
214
261
}
215
262
case gpu_arch::xe_hpc:
216
263
case gpu_arch::xe2:
217
264
case gpu_arch::xe3: {
218
- config = choose_config_xehpc (static_cast <int32_t >(head_size), static_cast <int32_t >(n_keys.v ), thin_q);
265
+ config = choose_config_xehpc (static_cast <int32_t >(head_size), static_cast <int32_t >(n_keys.v ), thin_q, is_quantized );
219
266
break ;
220
267
}
221
268
default : break ;
@@ -330,7 +377,7 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
330
377
}
331
378
332
379
if (params.conf .is_kv_compressed ) {
333
- problem_vs.aqGroupM = (vs_common_scales || vs_common_zp) ? 1 : params.conf .head_size ;
380
+ problem_vs.aqGroupM = (vs_common_scales || vs_common_zp) ? 1 : micro::rnd_up_pow2 ( params.conf .head_size ) ;
334
381
problem_vs.aqGroupK = 1 ;
335
382
}
336
383
0 commit comments