@@ -604,22 +604,37 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
604
604
auto convert_strides = [](std::string target_prefix, std::string source_prefix, const std::vector<int64_t > order) {
605
605
JitConstants definitions ({});
606
606
607
- std::vector<std::string> target_definitions = {
607
+ std::vector<std::string> target_stride_definitions = {
608
608
target_prefix + " _S0" ,
609
609
target_prefix + " _S1" ,
610
610
target_prefix + " _S2" ,
611
611
target_prefix + " _S3" ,
612
612
};
613
613
614
- std::vector<std::string> source_definitions = {
614
+ std::vector<std::string> source_stride_definitions = {
615
615
source_prefix + " _BATCH_PITCH" ,
616
616
source_prefix + " _FEATURE_PITCH" ,
617
617
source_prefix + " _Y_PITCH" ,
618
618
source_prefix + " _X_PITCH" ,
619
619
};
620
620
621
- for (size_t i = 0 ; i < target_definitions.size (); i++) {
622
- definitions.AddConstant (MakeJitConstant (target_definitions[i], source_definitions[order[i]]));
621
+ std::vector<std::string> target_size_definitions = {
622
+ target_prefix + " _D0" ,
623
+ target_prefix + " _D1" ,
624
+ target_prefix + " _D2" ,
625
+ target_prefix + " _D3" ,
626
+ };
627
+
628
+ std::vector<std::string> source_size_definitions = {
629
+ source_prefix + " _BATCH_NUM" ,
630
+ source_prefix + " _FEATURE_NUM" ,
631
+ source_prefix + " _SIZE_Y" ,
632
+ source_prefix + " _SIZE_X" ,
633
+ };
634
+
635
+ for (size_t i = 0 ; i < target_stride_definitions.size (); i++) {
636
+ definitions.AddConstant (MakeJitConstant (target_stride_definitions[i], source_stride_definitions[order[i]]));
637
+ definitions.AddConstant (MakeJitConstant (target_size_definitions[i], source_size_definitions[order[i]]));
623
638
}
624
639
625
640
return definitions;
@@ -635,6 +650,11 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
635
650
jit.Merge (unit_parameters (" VAL" ));
636
651
jit.Merge (unit_parameters (" DST" ));
637
652
653
+ if (params.inputs .size () > 3 ) {
654
+ jit.Merge (convert_strides (" MSK" , " INPUT3" , {0 , 1 , 2 , 3 }));
655
+ jit.Merge (unit_parameters (" MSK" ));
656
+ }
657
+
638
658
if (params.conf .is_kv_compressed ) {
639
659
jit.AddConstant (MakeJitConstant (" KEY_SCALE" , params.key_cache_comp_scale ));
640
660
jit.AddConstant (MakeJitConstant (" VAL_SCALE" , params.value_cache_comp_scale ));
0 commit comments