
Commit 47c7e5d: Speculative decoding related changes
1 parent cb6b03e

9 files changed, +178 -20 lines changed

src/plugins/intel_cpu/src/nodes/scaled_attn.cpp (+16 -1)

@@ -801,6 +801,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
     v_input.reset(inputs[2]);
     present_key.reset(presentk_input);
     present_value.reset(presentv_input);
+    // std::cout << "is_PA=" << is_pagedattn << " q_input=" << inputs[0]->getShape().toPartialShape() << " k_input" << inputs[1]->getShape().toPartialShape()
+    //           << " present_key=" << presentk_input->getShape().toPartialShape() << " present_value=" << presentv_input->getShape().toPartialShape() << "\n";
     if (is_pagedattn) {
         is_prompt = *inputs[ID_IS_PROMPT]->getDataAs<uint8_t>() == 1;
         //auto max_context_len = static_cast<size_t>(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs<int32_t>());
@@ -825,13 +827,17 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         // L0 in each batch may be different
         L0 = 0;
 
+        // std::cout << "Assert 1\n";
         q_input.assert_dims({B, L1, H * S});
         if (!is_prompt) {
+            // std::cout << "Assert 2\n";
             context_lens.assert_dims({B});
+            // std::cout << "Assert 3\n";
             beam_table.assert_dims({B, 0}, true);
         } else {
             sliding_window = static_cast<size_t>(*inputs[ID_SLIDING_WINDOW]->getDataAs<int32_t>());
         }
+        // std::cout << "Assert 4\n";
         output_emb.assert_dims({B, L1, H * S});
         q_input = q_input.reshape({B, L1, H, S}).permute({0, 2, 1, 3});
         k_input = k_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3});
@@ -872,22 +878,31 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
     auto Hk = k_input.size(1);
 
     if (fuse_concat) {
+        // std::cout << "Assert 5\n";
         k_input.assert_dims({B, Hk, L1, S});
+        // std::cout << "Assert 6\n";
         v_input.assert_dims({B, Hk, L1, S});
     } else {
+        // std::cout << "Assert 7\n";
         k_input.assert_dims({B, Hk, L0 + L1, S});
+        // std::cout << "Assert 8\n";
         v_input.assert_dims({B, Hk, L0 + L1, S});
     }
+    // std::cout << "Assert 9\n";
     present_key.assert_dims({B, Hk, L0 + L1, S});
+    // std::cout << "Assert 10\n";
     present_value.assert_dims({B, Hk, L0 + L1, S});
-    if (beam_table)
+    if (beam_table) {
+        // std::cout << "Assert 11\n";
         beam_table.assert_dims({B, L0 + L1});
+    }
 }
 
 bool auto_causal;
 bool use_attn_mask;
 if (fuse_causal_attn) {
     assert(attn_mask);
+    // std::cout << "Assert 12\n";
     attn_mask.assert_dims({B, 1, L1, L0 + L1});
     auto_causal = true;
     use_attn_mask = true;

src/plugins/intel_gpu/src/graph/graph_optimizer/mark_runtime_skippable_nodes.cpp (+2 -1)

@@ -101,7 +101,8 @@ void mark_runtime_skippable_nodes::run(program& p) {
     if (node.is_output()
         || node.has_fused_primitives()
         || (impl_params->get_input_layout(0).format != impl_params->get_output_layout().format)
-        || (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type))
+        || (impl_params->get_input_layout(0).data_type != impl_params->get_output_layout().data_type)
+        || node.is_in_shape_of_subgraph())
         return;
 
     if (node.is_dynamic()) {

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+14 -8)

@@ -71,13 +71,14 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                         instance.input_memory_ptr(4) /* value_cache */ };
     } else if (stage == Stage::SDPA) {
         if (kernel_idx == 0) {
-            args.inputs = { instance.input_memory_ptr(0), /* query */
-                            instance.input_memory_ptr(3), /* key_cache */
-                            instance.input_memory_ptr(4), /* value_cache */
-                            instance.input_memory_ptr(7), /* max_context_len */
-                            instance.input_memory_ptr(8), /* context_lens */
-                            instance.input_memory_ptr(9), /* block_tables */
-                            instance.input_memory_ptr(10) /* scale */ };
+            args.inputs = { instance.input_memory_ptr(0), /* query */
+                            instance.input_memory_ptr(3), /* key_cache */
+                            instance.input_memory_ptr(4), /* value_cache */
+                            instance.input_memory_ptr(7), /* max_context_len */
+                            instance.input_memory_ptr(8), /* context_lens */
+                            instance.input_memory_ptr(9), /* block_tables */
+                            instance.input_memory_ptr(10), /* scale */
+                            instance.input_memory_ptr(5) /* is_prompt */ };
         } else {
             args.inputs = { instance.input_memory_ptr(8), /* context_lens */ };
         }
@@ -212,10 +213,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
         auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);
 
-        const auto inputs_count = 7;
+        const auto inputs_count = 8;
         const auto query_layout = impl_param.get_input_layout(0);
         const auto key_cache_layout = impl_param.get_input_layout(3);
         const auto value_cache_layout = impl_param.get_input_layout(4);
+        const auto is_prompt_layout = impl_param.get_input_layout(5);
         const auto max_context_len_layout = impl_param.get_input_layout(7);
         const auto context_lens_layout = impl_param.get_input_layout(8);
         const auto block_tables_layout = impl_param.get_input_layout(9);
@@ -228,6 +230,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         params.inputs[4] = convert_data_tensor(context_lens_layout);
         params.inputs[5] = convert_data_tensor(block_tables_layout);
         params.inputs[6] = convert_data_tensor(scale_layout);
+        params.inputs[7] = convert_data_tensor(is_prompt_layout);
 
         params.configuration = get_sdpa_configuration(impl_param);
         if (!is_dynamic) {
@@ -240,6 +243,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             mem_lock<uint8_t, mem_lock_type::read> is_prompt_stage_mem_lock(is_prompt_stage_mem, impl_param.get_stream());
             bool is_prompt_stage = is_prompt_stage_mem_lock[0];
 
+
             if (is_prompt_stage) {
                 // Use number of slots for KV cache as a maximum context length for the first iteration
                 auto slot_mapping = impl_param.get_input_layout(6);
@@ -249,6 +253,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                 mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream());
                 params.configuration.max_context_len = max_context_len_mem_lock[0];
             }
+            // std::cout << "is_prompt_stage=" << is_prompt_stage << " params.configuration.max_context_len=" << params.configuration.max_context_len << "\n";
         }
 
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
@@ -261,6 +266,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             {4, in_offsets_map.at(8)},
             {5, in_offsets_map.at(9)},
             {6, in_offsets_map.at(10)},
+            {7, in_offsets_map.at(5)},
         };
         std::map<size_t, size_t> out_tensor_to_offset_map = {
             {0, out_offsets_map.at(0)},
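
Note: the kernel argument list above is purely index-based, so the PagedAttention primitive's input port order is load-bearing. A minimal reference sketch of that order, reconstructed only from the input_memory_ptr()/get_input_layout() indices visible in this diff; the enum itself is hypothetical (ports 1 and 2 are assumed to be key/value, since the diff never touches them):

```cpp
#include <cstddef>

// Hypothetical helper, not a type in the repository: names the PagedAttention
// primitive's input ports, per the index usage in paged_attention.cpp.
enum PagedAttentionInputPort : std::size_t {
    QUERY           = 0,
    KEY             = 1,   // assumed
    VALUE           = 2,   // assumed
    KEY_CACHE       = 3,
    VALUE_CACHE     = 4,
    IS_PROMPT       = 5,   // new: forwarded to the SDPA kernel as params.inputs[7]
    SLOT_MAPPING    = 6,
    MAX_CONTEXT_LEN = 7,
    CONTEXT_LENS    = 8,
    BLOCK_TABLES    = 9,
    SCALE           = 10,
};
```

This is also why the shape-info offset map gains the out-of-order entry {7, in_offsets_map.at(5)}: kernel input 7 is fed from primitive port 5.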

src/plugins/intel_gpu/src/graph/primitive_inst.cpp (+3)

@@ -1548,6 +1548,9 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
             _outputs = allocate_outputs();
         }
     }
+    if (_node) {
+        GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n";
+    }
     if (_impl) {
         _impl->set_node_params(node);
         if (_impl->is_dynamic() && !_impl->is_cpu()) {

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl (+5)

@@ -26,6 +26,11 @@ KERNEL(pa_kv_cache_update)(
     const uint block_index = slot_idx / KV_CACHE_BLOCK_SIZE;
     const uint block_offset = slot_idx % KV_CACHE_BLOCK_SIZE;
 
+
+    // if (batch_idx == 0 && hidden_idx == 0) {
+    //     printf("Update kv_cache %d: block_dx=%d offset=%d, slot_idx=%d\n", seq_idx, block_index, block_offset, slot_idx);
+    // }
+
 #ifdef VALUE_CACHE_UPDATE
     const uint out_offset = CACHE_BLOCK_STRIDE * block_index +
                             hidden_idx * KV_CACHE_BLOCK_SIZE +

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_ref.cl (+87 -4)

@@ -14,7 +14,7 @@
 #define Q_LOAD_ITERS (HEAD_SIZE / SUB_GROUP_SIZE)
 
 // How much QK outputs each subgroup calculates per block
-#define QK_VALS_PER_SG_PER_ITER (BLOCK_SIZE / SUBGROUPS_PER_WG)
+#define QK_VALS_PER_SG_PER_ITER CEIL_DIV(BLOCK_SIZE, SUBGROUPS_PER_WG)
 
 #define KV_CACHE_BLOCK_STRIDE (HEAD_SIZE * KV_HEADS_NUM * BLOCK_SIZE)
 
@@ -35,6 +35,7 @@ KERNEL(pa_sdpa_ref)(
     const __global INPUT4_TYPE* context_lens,
     const __global INPUT5_TYPE* block_tables,
     const __global INPUT6_TYPE* scale,
+    const __global INPUT7_TYPE* is_prompt,
 #ifdef USE_SEQ_LEN_SPLIT
     __global OUTPUT_TYPE* output,
     __global ACCUMULATOR_TYPE* exp_sums,
@@ -71,6 +72,10 @@ KERNEL(pa_sdpa_ref)(
 
     const uint total_blocks_num = CEIL_DIV(context_len, BLOCK_SIZE);
 
+    // if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
+    //     printf("context_len=%d block_start_idx=%d total_blocks_num=%d context_len=%d, SCALE_VAL=%f is_prompt=%d\n", context_len, block_start_idx, total_blocks_num, context_len, scale[0], is_prompt[0]);
+    // }
+
     __local OUTPUT_TYPE qk_vals_local[SHARED_MEM_SIZE];
     ACCUMULATOR_TYPE qk_max = ACCUMULATOR_VAL_MIN;
 
@@ -99,7 +104,12 @@ KERNEL(pa_sdpa_ref)(
     for (uint q_idx = 0; q_idx < Q_LOAD_ITERS; q_idx++) {
         for (uint qk_idx = 0; qk_idx < QK_VALS_PER_SG_PER_ITER; qk_idx++) {
             uint current_token = (block_start_idx + block_num) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx;
+#if BLOCK_SIZE % SUBGROUPS_PER_WG != 0
+            // TODO: Optimize for BLOCK_SIZE % SUBGROUPS_PER_WG != 0 case
+            if (current_token >= context_len || sgid >= BLOCK_SIZE / QK_VALS_PER_SG_PER_ITER)
+#else
             if (current_token >= context_len)
+#endif
                 continue;
 
             const uint key_idx = block_offset +
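
With QK_VALS_PER_SG_PER_ITER now rounded up, work-groups whose subgroup count does not divide BLOCK_SIZE over-allocate tokens per subgroup, and the new sgid guard idles the surplus subgroups instead of silently dropping the block's tail tokens. A worked sketch of the arithmetic; BLOCK_SIZE = 16 and SUBGROUPS_PER_WG = 5 are assumed example values (the debug comments further down suggest a 5-subgroup configuration), not constants taken from the kernel build options:

```cpp
// Standalone arithmetic check of the old vs. new per-subgroup token quota.
constexpr unsigned ceil_div(unsigned a, unsigned b) { return (a + b - 1) / b; }

constexpr unsigned BLOCK_SIZE = 16;        // assumed example value
constexpr unsigned SUBGROUPS_PER_WG = 5;   // assumed example value

// Old define: truncating division covers only 5 * 3 = 15 of 16 block tokens.
static_assert(BLOCK_SIZE / SUBGROUPS_PER_WG == 3, "old quota loses a token");

// New define: 4 tokens per subgroup; the guard sgid < BLOCK_SIZE / 4 keeps
// exactly 4 subgroups active, so 4 * 4 = 16 tokens are covered with no overlap.
static_assert(ceil_div(BLOCK_SIZE, SUBGROUPS_PER_WG) == 4, "new per-subgroup quota");
static_assert(BLOCK_SIZE / ceil_div(BLOCK_SIZE, SUBGROUPS_PER_WG) == 4, "active subgroups");
```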
@@ -120,27 +130,44 @@ KERNEL(pa_sdpa_ref)(
         }
     }
 
+    // if (context_len == 17 && sgid == 4 && QK_VALS_PER_SG_PER_ITER == 4 && (head_num_idx == 0 || head_num_idx == 1 || head_num_idx == 28)) {
+    //     printf("FROM SGID=4; token_idx=%d, head_num=%d block_num=%d, sglid=%d: %f %f %f %f \n", token_idx, head_num_idx, block_num, sglid,
+    //         qk[0], qk[1], qk[2], qk[3]);
+    // }
+
     // Summurize qk calculation across all WIs and apply scale
     for (uint qk_idx = 0; qk_idx < QK_VALS_PER_SG_PER_ITER; qk_idx++) {
         const uint current_token = (block_start_idx + block_num) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + qk_idx;
+#if BLOCK_SIZE % SUBGROUPS_PER_WG != 0
+        if (current_token < context_len && sgid < BLOCK_SIZE / QK_VALS_PER_SG_PER_ITER) {
+#else
         if (current_token < context_len) {
+#endif
            qk[qk_idx] = sub_group_reduce_add(qk[qk_idx]);
 
            // Apply scale
            qk[qk_idx] = scale[0] * qk[qk_idx];
 
            // Apply attention mask for context processing stage
-           const bool is_prefill_stage = INPUT0_FEATURE_NUM > 1;
-           if (is_prefill_stage && current_token > token_idx) {
-               qk[qk_idx] = qk[qk_idx] + OUTPUT_VAL_MIN;
+           const unsigned char is_prefill_stage = is_prompt[0];
+           if (is_prefill_stage == 1) {
+               if (current_token > token_idx)
+                   qk[qk_idx] = qk[qk_idx] + OUTPUT_VAL_MIN;
+           } else if (is_prefill_stage == 2) {
+               if (current_token > context_len - INPUT0_FEATURE_NUM + token_idx)
+                   qk[qk_idx] = qk[qk_idx] + OUTPUT_VAL_MIN;
           }
 
           qk_max = ACCUMULATOR_MAX_FUNC(qk_max, TO_ACCUMULATOR_TYPE(qk[qk_idx]));
        }
    }
 
    // Save QK results to local memory
+#if BLOCK_SIZE % SUBGROUPS_PER_WG != 0
+    if (sglid < QK_VALS_PER_SG_PER_ITER && sgid < BLOCK_SIZE / QK_VALS_PER_SG_PER_ITER) {
+#else
    if (sglid < QK_VALS_PER_SG_PER_ITER) {
+#endif
        const uint current_token_global_idx = (block_start_idx + block_num) * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid;
 #ifdef USE_SEQ_LEN_SPLIT
        const uint current_token_local = block_num * BLOCK_SIZE + sgid * QK_VALS_PER_SG_PER_ITER + sglid;
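
The reworked mask drops the old shape-based heuristic (is_prefill_stage = INPUT0_FEATURE_NUM > 1) in favor of a stage flag read from the new is_prompt input. From the branches above, the encoding appears to be: 0 for single-token decode (no mask), 1 for prefill (plain causal mask), and 2 for speculative multi-token decode, where the INPUT0_FEATURE_NUM query rows correspond to the last tokens of the context. A minimal host-side sketch of that rule; the function and parameter names are illustrative, not the kernel's:

```cpp
#include <cstddef>

// Returns true when the (query row, key token) score must be suppressed,
// mirroring the qk[qk_idx] + OUTPUT_VAL_MIN branches in the hunk above.
bool is_masked(unsigned char stage,        // is_prompt[0]; 0/1/2 encoding assumed
               std::size_t token_idx,      // query row within this launch
               std::size_t current_token,  // absolute key position in the context
               std::size_t context_len,    // total context length
               std::size_t q_len) {        // INPUT0_FEATURE_NUM: query rows per step
    if (stage == 1)   // prefill: standard causal mask over the prompt
        return current_token > token_idx;
    if (stage == 2)   // speculative decode: query rows map to the context tail
        return current_token > context_len - q_len + token_idx;
    return false;     // single-token decode: the lone query sees the whole context
}
```

For example, with context_len = 17 and q_len = 3 (three draft tokens being validated), query row 0 attends to keys 0..14, row 1 to 0..15, and row 2 to the full 0..16 context.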
@@ -152,6 +179,33 @@ KERNEL(pa_sdpa_ref)(
         }
     }
 
+    // barrier(CLK_LOCAL_MEM_FENCE);
+    // if (get_global_id(1) == 0 && get_global_id(2) == 0) {
+    //     if (context_len == 15)
+    //         printf("token_idx=%d, qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d\n",
+    //             token_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //             qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //             qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], is_prompt[0]);
+    //     else if (context_len == 16)
+    //         printf("token_idx=%d, qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d\n",
+    //             token_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //             qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //             qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], is_prompt[0]);
+    //     else if (context_len == 17)
+    //         printf("token_idx=%d, qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d\n",
+    //             token_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //             qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //             qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], is_prompt[0]);
+    // }
+
+    // barrier(CLK_LOCAL_MEM_FENCE);
+    // if (context_len == 17 && sgid == 4 && sglid == 0) {
+    //     printf("FROM SGID=4; token_idx=%d, head_num=%d qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d. qk_max=%f\n",
+    //         token_idx, head_num_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //         qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //         qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], is_prompt[0], qk_max);
+    // }
+
     // Apply SoftMax operation
     __local ACCUMULATOR_TYPE qk_max_vals[SUBGROUPS_PER_WG];
     __local ACCUMULATOR_TYPE qk_sum_vals[SUBGROUPS_PER_WG];
@@ -168,6 +222,16 @@ KERNEL(pa_sdpa_ref)(
     // Final max value after reduction across of all SG and WI
     qk_max = sub_group_reduce_max(qk_max);
 
+    // barrier(CLK_LOCAL_MEM_FENCE);
+    // if (context_len == 17 && get_global_id(2) == 0 && (head_num_idx == 1 || head_num_idx == 28) && SUBGROUPS_PER_WG == 5) {
+    //     printf("Calculation QK_VALS token_idx=%d, head_num=%d qk_vals_local: %f (-qk_max = %f, native_exp = %f), %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f(-qk_max = %f, native_exp = %f): %d. qk_max=%f (%f %f %f %f %f)\n",
+    //         token_idx, head_num_idx, qk_vals_local[0], TO_ACCUMULATOR_TYPE(qk_vals_local[0] - qk_max), native_exp(TO_ACCUMULATOR_TYPE(qk_vals_local[0]) - qk_max), qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //         qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //         qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15],
+    //         qk_vals_local[16], TO_ACCUMULATOR_TYPE(qk_vals_local[16] - qk_max), native_exp(TO_ACCUMULATOR_TYPE(qk_vals_local[16]) - qk_max),
+    //         is_prompt[0], qk_max, qk_max_vals[0], qk_max_vals[1], qk_max_vals[2], qk_max_vals[3], qk_max_vals[4]);
+    // }
+
     ACCUMULATOR_TYPE exp_sum = ACCUMULATOR_VAL_ZERO;
 #ifdef USE_SEQ_LEN_SPLIT
     const uint qk_num = (num_of_portions == 1) ? CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE)
@@ -189,6 +253,15 @@ KERNEL(pa_sdpa_ref)(
         }
     }
 
+
+    // barrier(CLK_LOCAL_MEM_FENCE);
+    // if (context_len == 17 && get_global_id(2) == 0) {
+    //     printf("UPDATED QK_VALS token_idx=%d, head_num=%d qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f: %d. qk_max=%f\n",
+    //         token_idx, head_num_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //         qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //         qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], is_prompt[0], qk_max);
+    // }
+
     exp_sum = sub_group_reduce_add(exp_sum);
 
     if (sglid == 0)
@@ -236,6 +309,16 @@ KERNEL(pa_sdpa_ref)(
         }
     }
 #endif
+
+
+    // if (context_len == 17 && get_global_id(2) == 0 && SUBGROUPS_PER_WG == 5) {
+    //     printf("SF result: token_idx=%d, head_num=%d qk_vals_local: %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f; Total qk_max=%f total sum=%f (%f %f %f %f %f)\n",
+    //         token_idx, head_num_idx, qk_vals_local[0], qk_vals_local[1], qk_vals_local[2], qk_vals_local[3], qk_vals_local[4],
+    //         qk_vals_local[5], qk_vals_local[6], qk_vals_local[7], qk_vals_local[8], qk_vals_local[9],
+    //         qk_vals_local[10], qk_vals_local[11], qk_vals_local[12], qk_vals_local[13], qk_vals_local[14], qk_vals_local[15], qk_vals_local[16], qk_max, exp_sum,
+    //         qk_sum_vals[0], qk_sum_vals[1], qk_sum_vals[2], qk_sum_vals[3], qk_sum_vals[4]);
+
+    // }
     }
 
     {

src/plugins/intel_gpu/src/kernel_selector/kernels/paged_attention/sdpa_kernel_ref.cpp (+2 -1)

@@ -12,7 +12,7 @@ namespace kernel_selector {
 // For kernel w/o split
 constexpr size_t max_sequence_length = 3072;
 
-constexpr size_t seq_len_portion_size = 256;
+constexpr size_t seq_len_portion_size = 512;
 constexpr size_t subgroup_size = 16;
 
 const Datatype softmax_acc_dt = Datatype::F32;
@@ -160,6 +160,7 @@ ParamsKey SDPAKernelRef::GetSupportedKey() const {
     ParamsKey key;
     key.EnableInputDataType(Datatype::F16);
     key.EnableInputDataType(Datatype::F32);
+    key.EnableInputDataType(Datatype::UINT8);
     key.EnableInputDataType(Datatype::INT32);
     key.EnableOutputDataType(Datatype::F16);
     key.EnableOutputDataType(Datatype::F32);
