Skip to content

Commit 2ef3ecf

Browse files
committed
WIP: FA impl
1 parent dbe1f69 commit 2ef3ecf

File tree

8 files changed

+321
-105
lines changed

8 files changed

+321
-105
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl

+213-34
Large diffs are not rendered by default.

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,6 @@ bool SDPAKernelBase::Validate(const Params& p) const {
2121
return true;
2222
}
2323

24-
JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const {
25-
JitConstants jit = MakeBaseParamsJitConstants(params);
26-
27-
return jit;
28-
}
29-
3024
KernelsData SDPAKernelBase::GetCommonKernelsData(const Params& params) const {
3125
KernelData kd = KernelData::Default<sdpa_params>(params);
3226

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h

-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class SDPAKernelBase : public KernelBaseOpenCL {
4040

4141
protected:
4242
bool Validate(const Params&) const override;
43-
virtual JitConstants GetJitConstants(const sdpa_params& params) const;
4443
KernelsData GetCommonKernelsData(const Params& params) const;
4544
};
4645
} // namespace kernel_selector

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp

+97-61
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace kernel_selector {
1111

12-
constexpr size_t seq_len_partition_size = 256;
12+
constexpr size_t seq_len_partition_size = 32;
1313
constexpr size_t subgroup_size = 16;
1414

1515
ParamsKey SDPAKernelOpt::GetSupportedKey() const {
@@ -48,11 +48,12 @@ bool SDPAKernelOpt::Validate(const Params& p) const {
4848
return true;
4949
}
5050

51-
JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params) const {
52-
auto jit = Parent::GetJitConstants(params);
51+
JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t kernel_idx) const {
52+
auto jit = MakeBaseParamsJitConstants(params);
5353

54-
const auto softmax_acc_dt = params.inputs[0].GetDType();
55-
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "ACCUMULATOR"));
54+
const auto softmax_acc_dt = Datatype::F32;
55+
// const auto softmax_acc_dt = params.inputs[0].GetDType();
56+
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
5657

5758
const auto& config = params.conf;
5859
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
@@ -63,6 +64,7 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params) const {
6364
jit.AddConstant(MakeJitConstant("USE_SEQ_LEN_SPLIT", 1));
6465
jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
6566
jit.AddConstant(MakeJitConstant("SLM_SIZE", seq_len_partition_size));
67+
jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(kernel_idx), 1));
6668

6769
return jit;
6870
}
@@ -74,18 +76,19 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
7476
const auto& key_input = params.inputs[1];
7577
const auto& output = params.outputs[0];
7678
if (!query_input.is_dynamic()) {
77-
const size_t seq_len = key_input.Y().v;
78-
const size_t num_of_partitions = CeilDiv(seq_len, seq_len_partition_size);
79+
const size_t source_seq_len = key_input.Y().v;
80+
const size_t target_seq_len = output.Y().v;
81+
const size_t num_of_partitions = CeilDiv(source_seq_len, seq_len_partition_size);
7982
const size_t head_size = static_cast<size_t>(params.conf.head_size);
8083

8184
if (kernel_idx == 0) {
8285
dispatch_data.gws = { output.Batch().v * output.Feature().v,
83-
output.Y().v,
86+
target_seq_len,
8487
head_size * num_of_partitions };
8588
dispatch_data.lws = { 1, 1, head_size };
8689
} else {
8790
dispatch_data.gws = { output.Batch().v * output.Feature().v,
88-
output.Y().v,
91+
target_seq_len,
8992
head_size };
9093
dispatch_data.lws = { 1, 1, subgroup_size };
9194
}
@@ -95,89 +98,122 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k
9598
}
9699

97100
KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const {
98-
KernelData kd = KernelData::Default<sdpa_params>(params);
99-
const auto& prim_params = dynamic_cast<const sdpa_params&>(params);
100-
101101
if (!Validate(params)) {
102102
return {};
103103
}
104104

105-
auto dispatchData = SetDefault(prim_params, 0);
106-
auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params);
107-
auto cldnn_jit = GetJitConstants(prim_params);
108-
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
109-
110-
auto& kernel = kd.kernels[0];
105+
const size_t kernels_num = 2;
106+
KernelData kd = KernelData::Default<sdpa_params>(params, kernels_num);
107+
kd.needs_sub_kernels_sync = true;
111108

112109
GetUpdateDispatchDataFunc(kd);
113110

114-
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point,
115-
"", false, false, static_cast<int>(prim_params.inputs.size()),
116-
GetFusedPrimitiveInputsCount(params), 1, prim_params.is_shape_agnostic);
111+
const auto& prim_params = dynamic_cast<const sdpa_params&>(params);
112+
for (size_t kernel_num = 0; kernel_num < kernels_num; kernel_num++) {
113+
auto dispatch_data = SetDefault(prim_params, kernel_num);
114+
auto kernel_name = kernel_num == 0 ? kernelName : "sdpa_opt_finalization";
115+
auto entry_point = GetEntryPoint(kernel_name, prim_params.layerID, params);
116+
auto jit_constants = GetJitConstants(prim_params, kernel_num);
117+
auto jit = CreateJit(kernel_name, jit_constants, entry_point);
118+
119+
auto& kernel = kd.kernels[kernel_num];
120+
121+
auto inputs_num = kernel_num == 1 ? 0 : static_cast<int>(prim_params.inputs.size());
122+
FillCLKernelData(kernel,
123+
dispatch_data,
124+
params.engineInfo,
125+
kernelName,
126+
jit,
127+
entry_point,
128+
{},
129+
false,
130+
false,
131+
inputs_num,
132+
GetFusedPrimitiveInputsCount(params),
133+
static_cast<int>(prim_params.outputs.size()),
134+
prim_params.is_shape_agnostic);
135+
136+
const auto num_of_partitions = 1;
137+
auto& output = prim_params.outputs[0];
138+
auto head_size = output.X().v;
117139

118-
auto num_of_partitions = 1;
140+
auto buf_dt_size = 4;
141+
auto buf_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() / head_size * num_of_partitions;
142+
auto buf_size = buf_elements_count * buf_dt_size;
119143

120-
auto& output = prim_params.outputs[0];
121-
auto buf_dt_size = 4;
122-
// auto buf_elements_count = tokens_num * prim_params.configuration.heads_num * num_of_portions;
123-
auto buf_elements_count = output.LogicalSize() / output.X().v * num_of_partitions;
124-
auto buf_size = buf_elements_count * buf_dt_size;
144+
auto tmp_out_dt_size = 4;
145+
auto tmp_out_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() * num_of_partitions;
146+
auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size;
125147

126-
auto tmp_out_dt_size = 4;
127-
auto tmp_out_elements_count = output.LogicalSize() / output.X().v * num_of_partitions * prim_params.conf.head_size;
128-
auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size;
148+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
149+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
150+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
129151

130-
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
131-
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
132-
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
152+
kd.internalBufferSizes.clear();
153+
kd.internalBufferSizes.push_back(buf_size);
154+
kd.internalBufferSizes.push_back(buf_size);
155+
kd.internalBufferSizes.push_back(tmp_out_size);
156+
kd.internalBufferDataType = prim_params.inputs[0].GetDType();
133157

134-
kd.internalBufferSizes.clear();
135-
kd.internalBufferSizes.push_back(buf_size);
136-
kd.internalBufferSizes.push_back(buf_size);
137-
kd.internalBufferSizes.push_back(tmp_out_size);
138-
kd.internalBufferDataType = prim_params.inputs[0].GetDType();
158+
GPU_DEBUG_TRACE_DETAIL << "configure SDPA " << kernel_num << "th kernel: inputs_num=" << inputs_num << " arguments_num=" << kernel.params.arguments.size() << "\n";
139159

140-
// ScalarDescriptor num_of_partitions_scalar;
141-
// num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
142-
// num_of_partitions_scalar.v.u32 = 1;
160+
if (kernel_num == 1) {
161+
kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
143162

144-
// kd.kernels[1].params.scalars.resize(1);
145-
// kd.kernels[1].params.scalars[0] = num_of_partitions_scalar;
163+
ScalarDescriptor num_of_partitions_scalar;
164+
num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
165+
num_of_partitions_scalar.v.u32 = num_of_partitions;
166+
167+
kernel.params.scalars.clear();
168+
kernel.params.scalars.push_back(num_of_partitions_scalar);
169+
}
170+
}
146171

147172
return { kd };
148173
}
149174

150175
void SDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) const {
151176
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kernel_data) {
152177
const auto& prim_params = static_cast<const sdpa_params&>(params);
153-
auto dispatchData = SetDefault(prim_params, 0);
154-
OPENVINO_ASSERT(kernel_data.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
155-
kernel_data.kernels[0].params.workGroups.global = dispatchData.gws;
156-
kernel_data.kernels[0].params.workGroups.local = dispatchData.lws;
157-
kernel_data.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params);
158-
159-
// auto& in_q = prim_params.inputs[0];
160-
// auto& in_k = prim_params.inputs[1];
161178

162-
// ScalarDescriptor num_of_partitions_scalar;
163-
// num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
164-
// num_of_partitions_scalar.v.u32 = 1;
179+
const size_t expected_kernels_num = 2;
180+
OPENVINO_ASSERT(kernel_data.kernels.size() == expected_kernels_num,
181+
"[GPU] Invalid kernels size for update dispatch data func of SDPA kernel");
165182

166-
// kernel_data.kernels[0].params.scalars.resize(1);
167-
// kernel_data.kernels[0].params.scalars[0] = num_of_partitions_scalar;
183+
auto& output = prim_params.outputs[0];
184+
auto& key_input = prim_params.inputs[1];
168185

169-
auto num_of_partitions = 1;
186+
auto head_size = output.X().v;
187+
auto source_seq_len = key_input.Y().v;
188+
auto num_of_partitions = CeilDiv(source_seq_len, seq_len_partition_size);
170189

171-
auto& output = prim_params.outputs[0];
172190
auto buf_dt_size = 4;
173-
// auto buf_elements_count = tokens_num * prim_params.configuration.heads_num * num_of_portions;
174-
auto buf_elements_count = output.LogicalSize() / output.X().v * num_of_partitions;
191+
auto buf_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() / head_size * num_of_partitions;
175192
auto buf_size = buf_elements_count * buf_dt_size;
176193

177194
auto tmp_out_dt_size = 4;
178-
auto tmp_out_elements_count = output.LogicalSize() / output.X().v * num_of_partitions * prim_params.conf.head_size;
195+
auto tmp_out_elements_count = (num_of_partitions == 1) ? 1 : output.LogicalSize() * num_of_partitions;
179196
auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size;
180197

198+
auto dispatch_data1 = SetDefault(prim_params, 0);
199+
kernel_data.kernels[0].params.workGroups.global = dispatch_data1.gws;
200+
kernel_data.kernels[0].params.workGroups.local = dispatch_data1.lws;
201+
kernel_data.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params);
202+
203+
ScalarDescriptor num_of_partitions_scalar;
204+
num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
205+
num_of_partitions_scalar.v.u32 = num_of_partitions;
206+
207+
auto dispatch_data2 = SetDefault(prim_params, 1);
208+
kernel_data.kernels[1].params.workGroups.global = dispatch_data2.gws;
209+
kernel_data.kernels[1].params.workGroups.local = dispatch_data2.lws;
210+
kernel_data.kernels[1].skip_execution = num_of_partitions == 1;
211+
212+
kernel_data.kernels[1].params.scalars.clear();
213+
kernel_data.kernels[1].params.scalars.push_back(num_of_partitions_scalar);
214+
GPU_DEBUG_TRACE_DETAIL << "update_dispatch_data_func SDPA 0th kernel: arguments_num=" << kernel_data.kernels[0].params.arguments.size() << "\n";
215+
GPU_DEBUG_TRACE_DETAIL << "update_dispatch_data_func SDPA 1th kernel: arguments_num=" << kernel_data.kernels[1].params.arguments.size() << "\n";
216+
181217
kernel_data.internalBufferSizes.clear();
182218
kernel_data.internalBufferSizes.push_back(buf_size);
183219
kernel_data.internalBufferSizes.push_back(buf_size);

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class SDPAKernelOpt : public SDPAKernelBase {
2121
bool Validate(const Params& p) const override;
2222
void GetUpdateDispatchDataFunc(KernelData& kd) const override;
2323
CommonDispatchData SetDefault(const sdpa_params& params, size_t kernel_idx) const;
24-
JitConstants GetJitConstants(const sdpa_params& params) const override;
24+
JitConstants GetJitConstants(const sdpa_params& params, size_t kernel_idx) const;
2525
std::vector<FusedOpType> GetSupportedFusedOps() const override {
2626
return {};
2727
}

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ ParamsKey SDPAKernelRef::GetSupportedKey() const {
2929
}
3030

3131
JitConstants SDPAKernelRef::GetJitConstants(const sdpa_params& params) const {
32-
auto jit = Parent::GetJitConstants(params);
32+
auto jit = MakeBaseParamsJitConstants(params);
3333

3434
jit.Merge(MakeTypeJitConstants(Datatype::F16, "ACCUMULATOR"));
3535

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class SDPAKernelRef : public SDPAKernelBase {
2020
protected:
2121
void GetUpdateDispatchDataFunc(KernelData& kd) const override;
2222
CommonDispatchData SetDefault(const sdpa_params& params) const;
23-
JitConstants GetJitConstants(const sdpa_params& params) const override;
23+
JitConstants GetJitConstants(const sdpa_params& params) const;
2424
std::vector<FusedOpType> GetSupportedFusedOps() const override {
2525
return {};
2626
}

src/plugins/intel_gpu/src/runtime/ocl/ocl_stream.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,15 @@ void set_arguments_impl(ocl_kernel_type& kernel,
7979
const kernel_arguments_data& data) {
8080
using args_t = argument_desc::Types;
8181
using scalar_t = scalar_desc::Types;
82+
GPU_DEBUG_TRACE_DETAIL << "Total args " << args.size() << "\n";
83+
GPU_DEBUG_TRACE_DETAIL << "data.inputs.size() = " << data.inputs.size() << "\n";
84+
GPU_DEBUG_TRACE_DETAIL << "data.intermediates.size() = " << data.intermediates.size() << "\n";
85+
GPU_DEBUG_TRACE_DETAIL << "data.outputs.size() = " << data.outputs.size() << "\n";
86+
if (data.scalars)
87+
GPU_DEBUG_TRACE_DETAIL << "data.scalars->size() = " << data.scalars->size() << "\n";
88+
GPU_DEBUG_TRACE_DETAIL << "data.shape_info = " << data.shape_info << "\n";
8289
for (uint32_t i = 0; i < static_cast<uint32_t>(args.size()); i++) {
90+
GPU_DEBUG_TRACE_DETAIL << "setting " << static_cast<size_t>(args[i].t) << " index=" << args[i].index << "\n";
8391
cl_int status = CL_INVALID_ARG_VALUE;
8492
switch (args[i].t) {
8593
case args_t::INPUT:

0 commit comments

Comments
 (0)