@@ -508,7 +508,9 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
508
508
print_arr (q_dims, q_dims.size (), " q_dims" );
509
509
print_arr (k_dims, k_dims.size (), " k_dims" );
510
510
print_arr (v_dims, v_dims.size (), " v_dims" );
511
- std::cout << Q_num_heads_dim.is_dynamic << " " << K_num_heads_dim.is_dynamic << " " << V_num_heads_dim.is_dynamic << " " << K_num_heads_dim.v << " " << V_num_heads_dim.v << " \n " ;
511
+ std::cout << Q_num_heads_dim.is_dynamic << " "
512
+ << K_num_heads_dim.is_dynamic << " "
513
+ << V_num_heads_dim.is_dynamic << " " << K_num_heads_dim.v << " " << V_num_heads_dim.v << " \n " ;
512
514
return false ;
513
515
}
514
516
@@ -555,6 +557,7 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
555
557
jit.AddConstant (MakeJitConstant (" SUBGROUP_SIZE" , subgroup_size (prim_params.engineInfo .arch )));
556
558
jit.AddConstant (MakeJitConstant (" INVERT_SCALE" , false ));
557
559
jit.AddConstant (MakeJitConstant (" SCALE_DATA_T" , " half" ));
560
+ jit.AddConstant (MakeJitConstant (" HEAD_SIZE" , head_size));
558
561
559
562
jit.AddConstant (MakeJitConstant (" WITH_ATTN_MASK" , sdpa_inputs > 3 ));
560
563
jit.AddConstant (MakeJitConstant (" WITH_SCALE" , sdpa_inputs > 4 ));
@@ -733,6 +736,8 @@ CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params, const
733
736
auto seq_length = get_seq_length (params.inputs [0 ], params.input0_order ).v ;
734
737
if (params.conf .is_paged_attention ) {
735
738
seq_length = params.conf .paged_attention_aligned_seq_len ;
739
+ GPU_DEBUG_TRACE_DETAIL << " seq_len=" << seq_length << " \n " ;
740
+ GPU_DEBUG_TRACE_DETAIL << " wg_tile_q=" << wg_tile_q << " \n " ;
736
741
}
737
742
738
743
dispatch_data.gws [0 ] *= CeilDiv (seq_length, wg_tile_q);
@@ -767,6 +772,14 @@ clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, bool is
767
772
kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INPUT, 2 }); // V
768
773
kernel.params .arguments .push_back ({ArgumentDescriptor::Types::OUTPUT, 0 }); // A
769
774
775
+
776
+ if (params.conf .is_paged_attention ) {
777
+ kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INPUT, 3 }); // subsequence_begins
778
+
779
+ kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0 });
780
+ kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1 });
781
+ kernel.params .arguments .push_back ({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2 });
782
+ }
770
783
// if (params.inputs.size() >= 4)
771
784
// kernel.params.arguments.push_back({ArgumentDescriptor::Types::INPUT, 3}); // mask
772
785
// if (params.inputs.size() >= 5)
0 commit comments