Skip to content

Commit a01335c

Browse files
committed
[GPU] Add support of per-head mask for sdpa_micro kernel
1 parent 7260cc0 commit a01335c

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

+3
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
187187
Q += QRY_OFF(b1, b0, 0, 0) + INPUT0_OFFSET;
188188
V += VAL_OFF(b1, (b0 / KV_GROUP_SIZE), 0, 0) + INPUT2_OFFSET;
189189
A += DST_OFF(b1, b0, 0, 0, 0);
190+
#if WITH_ATTN_MASK
191+
msk += MSK_OFF(b1 % MSK_D0, b0 % MSK_D1, 0, 0);
192+
#endif
190193

191194
__builtin_assume_aligned(K, K_ALIGN);
192195
__builtin_assume_aligned(Q, Q_ALIGN);

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_micro.cpp

+24-4
Original file line numberDiff line numberDiff line change
@@ -439,22 +439,37 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
439439
auto convert_strides = [](std::string target_prefix, std::string source_prefix, const std::vector<int64_t> order) {
440440
JitConstants definitions({});
441441

442-
std::vector<std::string> target_definitions = {
442+
std::vector<std::string> target_stride_definitions = {
443443
target_prefix + "_S0",
444444
target_prefix + "_S1",
445445
target_prefix + "_S2",
446446
target_prefix + "_S3",
447447
};
448448

449-
std::vector<std::string> source_definitions = {
449+
std::vector<std::string> source_stride_definitions = {
450450
source_prefix + "_BATCH_PITCH",
451451
source_prefix + "_FEATURE_PITCH",
452452
source_prefix + "_Y_PITCH",
453453
source_prefix + "_X_PITCH",
454454
};
455455

456-
for (size_t i = 0; i < target_definitions.size(); i++) {
457-
definitions.AddConstant(MakeJitConstant(target_definitions[i], source_definitions[order[i]]));
456+
std::vector<std::string> target_size_definitions = {
457+
target_prefix + "_D0",
458+
target_prefix + "_D1",
459+
target_prefix + "_D2",
460+
target_prefix + "_D3",
461+
};
462+
463+
std::vector<std::string> source_size_definitions = {
464+
source_prefix + "_BATCH_NUM",
465+
source_prefix + "_FEATURE_NUM",
466+
source_prefix + "_SIZE_Y",
467+
source_prefix + "_SIZE_X",
468+
};
469+
470+
for (size_t i = 0; i < target_stride_definitions.size(); i++) {
471+
definitions.AddConstant(MakeJitConstant(target_stride_definitions[i], source_stride_definitions[order[i]]));
472+
definitions.AddConstant(MakeJitConstant(target_size_definitions[i], source_size_definitions[order[i]]));
458473
}
459474

460475
return definitions;
@@ -470,6 +485,11 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
470485
jit.Merge(unit_parameters("VAL"));
471486
jit.Merge(unit_parameters("DST"));
472487

488+
if (params.inputs.size() > 3) {
489+
jit.Merge(convert_strides("MSK", "INPUT3", {0, 1, 2, 3}));
490+
jit.Merge(unit_parameters("MSK"));
491+
}
492+
473493
return jit;
474494
}
475495

0 commit comments

Comments
 (0)