@@ -439,22 +439,37 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
439
439
auto convert_strides = [](std::string target_prefix, std::string source_prefix, const std::vector<int64_t > order) {
440
440
JitConstants definitions ({});
441
441
442
- std::vector<std::string> target_definitions = {
442
+ std::vector<std::string> target_stride_definitions = {
443
443
target_prefix + " _S0" ,
444
444
target_prefix + " _S1" ,
445
445
target_prefix + " _S2" ,
446
446
target_prefix + " _S3" ,
447
447
};
448
448
449
- std::vector<std::string> source_definitions = {
449
+ std::vector<std::string> source_stride_definitions = {
450
450
source_prefix + " _BATCH_PITCH" ,
451
451
source_prefix + " _FEATURE_PITCH" ,
452
452
source_prefix + " _Y_PITCH" ,
453
453
source_prefix + " _X_PITCH" ,
454
454
};
455
455
456
- for (size_t i = 0 ; i < target_definitions.size (); i++) {
457
- definitions.AddConstant (MakeJitConstant (target_definitions[i], source_definitions[order[i]]));
456
+ std::vector<std::string> target_size_definitions = {
457
+ target_prefix + " _D0" ,
458
+ target_prefix + " _D1" ,
459
+ target_prefix + " _D2" ,
460
+ target_prefix + " _D3" ,
461
+ };
462
+
463
+ std::vector<std::string> source_size_definitions = {
464
+ source_prefix + " _BATCH_NUM" ,
465
+ source_prefix + " _FEATURE_NUM" ,
466
+ source_prefix + " _SIZE_Y" ,
467
+ source_prefix + " _SIZE_X" ,
468
+ };
469
+
470
+ for (size_t i = 0 ; i < target_stride_definitions.size (); i++) {
471
+ definitions.AddConstant (MakeJitConstant (target_stride_definitions[i], source_stride_definitions[order[i]]));
472
+ definitions.AddConstant (MakeJitConstant (target_size_definitions[i], source_size_definitions[order[i]]));
458
473
}
459
474
460
475
return definitions;
@@ -470,6 +485,11 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
470
485
jit.Merge (unit_parameters (" VAL" ));
471
486
jit.Merge (unit_parameters (" DST" ));
472
487
488
+ if (params.inputs .size () > 3 ) {
489
+ jit.Merge (convert_strides (" MSK" , " INPUT3" , {0 , 1 , 2 , 3 }));
490
+ jit.Merge (unit_parameters (" MSK" ));
491
+ }
492
+
473
493
return jit;
474
494
}
475
495
0 commit comments