Skip to content

Commit dbe1f69

Browse files
committed Apr 24, 2024
Initial SDPA opt version
1 parent e1ca2fa commit dbe1f69

11 files changed

+461
-39
lines changed
 

‎src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp

+12-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "primitive_base.hpp"
66
#include "scaled_dot_product_attention_inst.h"
77
#include "sdpa/sdpa_kernel_selector.h"
8-
#include "sdpa/sdpa_kernel_ref.h"
8+
#include "sdpa/sdpa_kernel_base.h"
99

1010
namespace cldnn {
1111
namespace ocl {
@@ -21,6 +21,16 @@ struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_p
2121
return make_unique<scaled_dot_product_attention_impl>(*this);
2222
}
2323

24+
// Builds the SDPA kernel configuration from the query tensor (input 0).
//
// Only `head_size` is filled in here: it is taken from the innermost (last)
// dimension of the query partial shape when that dimension is static.
// Fields that cannot be determined keep their default value from
// `sdpa_configuration`.
static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
    kernel_selector::sdpa_configuration config;

    const auto query_shape = impl_param.get_input_layout(0).get_partial_shape();

    // A shape without dimensions has no "last" dimension; `size() - 1` would
    // wrap around to a huge unsigned index, so leave the defaults untouched.
    if (query_shape.size() == 0)
        return config;

    const auto& head_size_dim = query_shape[query_shape.size() - 1];
    if (head_size_dim.is_static())
        config.head_size = head_size_dim.get_length();

    return config;
}
33+
2434
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic) {
2535
const auto& primitive = impl_param.typed_desc<scaled_dot_product_attention>();
2636
auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic);
@@ -31,11 +41,7 @@ struct scaled_dot_product_attention_impl : typed_primitive_impl_ocl<scaled_dot_p
3141
params.inputs[2] = convert_data_tensor(impl_param.get_input_layout(2));
3242
params.inputs[3] = convert_data_tensor(impl_param.get_input_layout(3));
3343

34-
// std::cout << impl_param.typed_desc<scaled_dot_product_attention>()->id << "in[0] " << impl_param.get_input_layout(0).to_short_string() << "\n";
35-
// std::cout << impl_param.typed_desc<scaled_dot_product_attention>()->id << "in[1] " << impl_param.get_input_layout(1).to_short_string() << "\n";
36-
// std::cout << impl_param.typed_desc<scaled_dot_product_attention>()->id << "in[2] " << impl_param.get_input_layout(2).to_short_string() << "\n";
37-
// std::cout << impl_param.typed_desc<scaled_dot_product_attention>()->id << "in[3] " << impl_param.get_input_layout(3).to_short_string() << "\n";
38-
// std::cout << impl_param.typed_desc<scaled_dot_product_attention>()->id << "out[0] " << impl_param.get_output_layout(0).to_short_string() << "\n";
44+
params.conf = get_sdpa_configuration(impl_param);
3945

4046
params.set_dynamic_shape_offsets();
4147

‎src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl

+8
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,15 @@ KERNEL(gemm_ref)(
143143
ACCUMULATOR_TYPE val0 = TO_ACCUMULATOR_TYPE(input0[in0_idx]);
144144
ACCUMULATOR_TYPE val1 = TO_ACCUMULATOR_TYPE(input1[in1_idx]);
145145

146+
147+
// ACCUMULATOR_TYPE tmp_acc = acc;
146148
acc += val0 * val1;
149+
// if ((x < 2) && (y < 2) && get_global_id(2) == 0) {
150+
// printf("y=%d(%d). x=%d(%d). ki=%d. %f = %f * %f + %f (in0_idx=%d, in1_idx=%d), %d %d %d\n",
151+
// y, OUTPUT_SIZE_Y, x, OUTPUT_SIZE_X, ki, acc, val0, val1, tmp_acc,
152+
// in0_idx, in1_idx, get_global_id(0), get_global_id(1), get_global_id(2)
153+
// );
154+
// }
147155
}
148156

149157
acc = TO_ACCUMULATOR_TYPE(ALPHA) * acc;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "include/batch_headers/fetch_data.cl"
6+
#include "include/batch_headers/common.cl"
7+
#include "include/batch_headers/sub_group_block_read.cl"
8+
#include "include/batch_headers/sub_group_block_write.cl"
9+
#include "include/batch_headers/sub_group_shuffle.cl"
10+
11+
// query_input [batch, heads_num, q_len, head_size]
12+
// key_input [batch, kv_heads_num, kv_len, head_size]
13+
// value_input [batch, kv_heads_num, kv_len, head_size]
14+
// attn_mask [1, 1, q_len, kv_len]
15+
// output [batch, heads_num, q_len, head_size]
16+
// tmp_buf [batch, heads_num, q_len, kv_len]
17+
18+
// VLOAD: loads SUBGROUP_SIZE consecutive elements with vloadN.
// For non-32-bit output types the data is loaded as raw 16-bit words and is
// expected to be reinterpreted by the caller (see AS_VALUE_VEC). The cast keeps
// the `const` qualifier so read-only input pointers are not silently made
// writable.
#if OUTPUT_TYPE_SIZE == 4
    #define VLOAD(offset, ptr) CAT(vload, SUBGROUP_SIZE)(offset, ptr)
#else
    #define VLOAD(offset, ptr) CAT(vload, SUBGROUP_SIZE)(offset, (const __global ushort*)(ptr))
#endif
// Vector of SUBGROUP_SIZE key elements and the matching as_<type>() reinterpret cast.
#define KEY_VEC_TYPE MAKE_VECTOR_TYPE(INPUT1_TYPE, SUBGROUP_SIZE)
#define AS_VALUE_VEC(val) CAT(as_, KEY_VEC_TYPE)(val)

// Sub-group block reads of one element per work item from the query / value tensors.
#define QUERY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)
#define VALUE_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT2_TYPE, 1, ptr, offset)

// Total key/value sequence length (Y dimension of the key input).
#define TOTAL_SEQ_LEN INPUT1_SIZE_Y

// Number of sub-groups needed to cover one head (HEAD_SIZE work items).
#define SUBGROUPS_PER_WG (HEAD_SIZE / SUBGROUP_SIZE)
32+
33+
// Scaled dot-product attention: output = softmax(Q * K^T * scale + attn_mask) * V.
//
// Work-item mapping (as read from the id queries below):
//   global id 0 -> batch * heads (split with INPUT0_FEATURE_NUM)
//   global id 1 -> query sequence index
//   global id 2 -> head-size element; work-groups along this axis are treated
//                  as partitions of the key/value sequence
//
// Each work-group:
//   1. Gemm1   - computes the Q*K^T row for its partition into local memory,
//   2. SoftMax - normalizes that row (max / exp / sum reductions across sub-groups),
//   3. Gemm2   - multiplies the normalized row by V.
// With a single partition the result goes straight to `output`; with several
// partitions the partial result goes to `tmp_out` and the per-partition softmax
// statistics go to `exp_sums` / `max_logits`.
//
// NOTE(review): key/value rows are not yet offset by the partition start (see the
// TODOs below), and the exp_sums / max_logits / tmp_out offsets carry no batch
// term - confirm against the host-side dispatch before enabling several
// partitions or batch > 1.
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(sdpa_opt)(
    OPTIONAL_SHAPE_INFO_ARG
    const __global INPUT0_TYPE* query_input,
    const __global INPUT1_TYPE* key_input,
    const __global INPUT2_TYPE* value_input,
    const __global INPUT3_TYPE* attn_mask,
    __global OUTPUT_TYPE* output,
    __global ACCUMULATOR_TYPE* exp_sums,
    __global ACCUMULATOR_TYPE* max_logits,
    __global OUTPUT_TYPE* tmp_out
)
{
    uint dim0 = get_global_id(0);
    uint batch_idx = dim0 / INPUT0_FEATURE_NUM;
    uint head_num_idx = dim0 % INPUT0_FEATURE_NUM;
    uint seq_idx = get_global_id(1);
    uint head_size_idx = get_global_id(2);

    const uint lid = get_local_id(2);
    const uint sgid = get_sub_group_id();
    const uint sglid = get_sub_group_local_id();

    const uint partition_id = get_group_id(2);
    const uint num_of_partitions = get_num_groups(2);
    const uint wi_num_per_partition = get_local_size(2);

    // Every partition except the last one is full. The last one holds the
    // remainder, unless the sequence length is an exact (non-zero) multiple of
    // the partition size - then it is full as well. A plain modulo would give
    // 0 in that case and the partition would be skipped.
    const uint tail_seq_len = TOTAL_SEQ_LEN % SEQ_LEN_PARTITION_SIZE;
    const bool is_full_partition = ((partition_id + 1) < num_of_partitions) ||
                                   (tail_seq_len == 0 && TOTAL_SEQ_LEN != 0);
    const uint partition_seq_len = is_full_partition ? (SEQ_LEN_PARTITION_SIZE) : tail_seq_len;

    __local OUTPUT_TYPE qk_vals_local[SLM_SIZE];
    ACCUMULATOR_TYPE qk_max = ACCUMULATOR_VAL_MIN;

#ifndef INPUT4_TYPE
    const OUTPUT_TYPE scale_val = OUTPUT_VAL_ONE / sqrt(TO_OUTPUT_TYPE(HEAD_SIZE));
#else
    // scale_val is used unconditionally below, and this kernel has no argument
    // to read an explicit scale from.
    #error "sdpa_opt: explicit scale input (INPUT4) is not supported"
#endif

    /* Calculate Gemm1 */
    // The loop body uses sub-group block reads and broadcasts, which must be
    // reached by every work item of the sub-group. The trip count is therefore
    // made sub-group uniform with sub_group_any(); lanes that ran past the end
    // of the partition compute on row 0 (valid whenever any lane is in range)
    // and simply discard the result.
    for (uint seq_len = lid; sub_group_any(seq_len < partition_seq_len); seq_len += wi_num_per_partition) {
        const bool in_range = seq_len < partition_seq_len;
        const uint key_row = in_range ? seq_len : 0;

        uint query_offset = INPUT0_GET_INDEX(batch_idx, head_num_idx, seq_idx, 0);
        uint key_offset = INPUT1_GET_INDEX(batch_idx, head_num_idx, /* TODO: start_partition_idx + seq_len */ key_row, 0);

        INPUT0_TYPE acc = INPUT0_VAL_ZERO;
        unroll_for (uint h = 0; h < HEAD_SIZE; h += SUBGROUP_SIZE) {
            // Each lane holds one query element; the key row chunk is held
            // entirely by every lane, so the dot product needs a broadcast of
            // each query element.
            INPUT0_TYPE query_val = QUERY_BLOCK_READ(query_input, query_offset);
            KEY_VEC_TYPE key_vec = AS_VALUE_VEC(VLOAD(0, key_input + key_offset));

            unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) {
                acc = mad(sub_group_broadcast(query_val, i), key_vec[i], acc);
            }

            query_offset += SUBGROUP_SIZE;
            key_offset += SUBGROUP_SIZE;
        }

        if (in_range) {
            // Apply scale
            acc *= scale_val;

            // Apply attention mask
            uint attn_mask_offset = INPUT3_GET_INDEX_SAFE(batch_idx, head_num_idx, seq_idx, /* TODO: start_partition_idx + seq_len */ seq_len);
            acc += attn_mask[attn_mask_offset];

            // Update qk_max value
            qk_max = ACCUMULATOR_MAX_FUNC(qk_max, TO_ACCUMULATOR_TYPE(acc));

            qk_vals_local[seq_len] = acc;
        }
    }

    /* Apply SoftMax */
    // One slot per sub-group; the second reduction stage reads slot `sglid`,
    // so this relies on SUBGROUPS_PER_WG <= SUBGROUP_SIZE.
    __local ACCUMULATOR_TYPE qk_max_vals[SUBGROUPS_PER_WG];
    __local ACCUMULATOR_TYPE qk_sum_vals[SUBGROUPS_PER_WG];
    {
        // Find the maximum value of qk in the subgroup
        qk_max = sub_group_reduce_max(qk_max);

        // Find the maximum value of qk across all subgroups in the workgroup
        if (sglid == 0)
            qk_max_vals[sgid] = qk_max;

        // Also makes the Gemm1 results in qk_vals_local visible to the whole work-group
        barrier(CLK_LOCAL_MEM_FENCE);

        qk_max = ACCUMULATOR_VAL_MIN;
        if (sglid < SUBGROUPS_PER_WG)
            qk_max = qk_max_vals[sglid];

        // Final maximum value of qk after reduction across all subgroups
        qk_max = sub_group_reduce_max(qk_max);

        // Exponentiate (shifted by the maximum for numerical stability) and accumulate the sum
        ACCUMULATOR_TYPE exp_sum = ACCUMULATOR_VAL_ZERO;
        const uint qk_num_per_wi = CEIL_DIV(partition_seq_len, SUBGROUPS_PER_WG * SUBGROUP_SIZE);
        for (uint qk_idx = 0; qk_idx < qk_num_per_wi; qk_idx++) {
            const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
            if (local_data_idx < partition_seq_len) {
                ACCUMULATOR_TYPE qk_new = native_exp(TO_ACCUMULATOR_TYPE(qk_vals_local[local_data_idx]) - qk_max);
                qk_vals_local[local_data_idx] = TO_OUTPUT_TYPE(qk_new);

                exp_sum += qk_new;
            }
        }

        exp_sum = sub_group_reduce_add(exp_sum);

        if (sglid == 0)
            qk_sum_vals[sgid] = exp_sum;

        barrier(CLK_LOCAL_MEM_FENCE);

        exp_sum = ACCUMULATOR_VAL_ZERO;

        if (sglid < SUBGROUPS_PER_WG)
            exp_sum = qk_sum_vals[sglid];

        // Find the final sum of all exp_sum values in workgroup
        exp_sum = sub_group_reduce_add(exp_sum);

        // Normalize: each work item rescales exactly the elements it exponentiated above
        const ACCUMULATOR_TYPE inv_sum = ACCUMULATOR_VAL_ONE / exp_sum;
        for (uint qk_idx = 0; qk_idx < qk_num_per_wi; qk_idx++) {
            const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
            if (local_data_idx < partition_seq_len) {
                ACCUMULATOR_TYPE qk_new = TO_ACCUMULATOR_TYPE(qk_vals_local[local_data_idx]) * inv_sum;
                qk_vals_local[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
            }
        }

        // Gemm2 reads the whole normalized row, so wait for every work item
        barrier(CLK_LOCAL_MEM_FENCE);

        {
            // Save temporary exp_sums and max_logits values for each partition.
            // The values are identical in every work item after the reductions,
            // so a single work item performs the store.
            if (num_of_partitions > 1 && sgid == 0 && sglid == 0) {
                const uint exp_sums_offset = seq_idx * HEADS_NUM * num_of_partitions +
                                             head_num_idx * num_of_partitions +
                                             partition_id;
                exp_sums[exp_sums_offset] = exp_sum;

                const uint max_logits_offset = exp_sums_offset;
                max_logits[max_logits_offset] = qk_max;
            }
        }
    }

    /* Calculate Gemm2 */
    {
        OUTPUT_TYPE acc = OUTPUT_VAL_ZERO;
        for (uint seq_len = 0; seq_len < partition_seq_len; seq_len++) {
            // The value tensor is input 2, so it is addressed with its own layout macro
            const uint value_offset = INPUT2_GET_INDEX(batch_idx, head_num_idx, /* TODO: start_partition_idx + seq_len */ seq_len, head_size_idx);

            // Every work item reads the same softmax weight from local memory
            // and its own head-size element of the value row.
            // NOTE(review): value_offset differs per lane while being passed to a
            // sub-group block read - confirm the read expects a per-lane offset.
            OUTPUT_TYPE qk_val = qk_vals_local[seq_len];
            INPUT2_TYPE value_val = VALUE_BLOCK_READ(value_input, value_offset);

            acc = mad(qk_val, value_val, acc);
        }

        if (num_of_partitions > 1) {
            const uint tmp_out_offset = seq_idx * (HEADS_NUM * HEAD_SIZE * num_of_partitions) +
                                        head_num_idx * (HEAD_SIZE * num_of_partitions) +
                                        partition_id * HEAD_SIZE +
                                        sgid * SUBGROUP_SIZE +
                                        sglid;

            // tmp_output data layout [num_seqs, num_heads, num_portions, head_size]
            tmp_out[tmp_out_offset] = acc;
        } else {
            const uint output_offset = OUTPUT_GET_INDEX(batch_idx, head_num_idx, seq_idx, head_size_idx);

            output[output_offset] = acc;
        }
    }
}

‎src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_ref.cl

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ KERNEL(sdpa_ref)(
3737
const OUTPUT_TYPE scale = OUTPUT_VAL_ONE / sqrt(TO_OUTPUT_TYPE(INPUT1_SIZE_X));
3838
#endif
3939

40+
// Process 1*seq_len elements (Gemm1 + SoftMax) using a single work item, saving results to tmp_buf and
41+
// reusing them between all work items within a single workgroup for Gemm2 calculations.
4042
if (get_local_id(2) == 0) {
4143
for (uint s = 0; s < INPUT1_SIZE_Y /* seq_len */; s++) {
4244
OUTPUT_TYPE acc = 0;

‎src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.cpp

-17
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,6 @@ bool SDPAKernelBase::Validate(const Params& p) const {
2121
return true;
2222
}
2323

24-
// Computes the default dispatch sizes for an SDPA kernel.
// The global work size is { batch, feature, W * Z * Y * X } of the first output;
// the local work size is chosen by GetOptimalLocalWorkGroupSizes() from that
// global size, the engine info and the input/output layouts.
CommonDispatchData SDPAKernelBase::SetDefault(const sdpa_params& params) const {
    const auto& out_tensor = params.outputs[0];

    auto input_layout = params.inputs[0].GetLayout();
    auto output_layout = out_tensor.GetLayout();

    // Tensor channels covered by each of the three global work size dimensions
    std::vector<std::vector<Tensor::DataChannelName>> channels_per_gws_dim = {
        { Tensor::DataChannelName::BATCH },
        { Tensor::DataChannelName::FEATURE },
        { Tensor::DataChannelName::X, Tensor::DataChannelName::Y,
          Tensor::DataChannelName::Z, Tensor::DataChannelName::W }
    };

    CommonDispatchData dispatch;
    dispatch.gws = { out_tensor.Batch().v,
                     out_tensor.Feature().v,
                     out_tensor.W().v * out_tensor.Z().v * out_tensor.Y().v * out_tensor.X().v };
    dispatch.lws = GetOptimalLocalWorkGroupSizes(dispatch.gws,
                                                 params.engineInfo,
                                                 input_layout,
                                                 output_layout,
                                                 channels_per_gws_dim);

    return dispatch;
}
40-
4124
JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const {
4225
JitConstants jit = MakeBaseParamsJitConstants(params);
4326

‎src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h

+7-15
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,19 @@
99
#include <vector>
1010

1111
namespace kernel_selector {
12+
// Shape-derived settings of an SDPA operation passed to the kernel selector.
// Every field starts at -1 and keeps that value until it is explicitly set.
struct sdpa_configuration {
    // Size of a single attention head
    int64_t head_size{-1};
    // Number of query heads (presumably; not set by the code visible here - TODO confirm)
    int64_t heads_num{-1};
    // Number of key/value heads (presumably; not set by the code visible here - TODO confirm)
    int64_t kv_heads_num{-1};
};
17+
1218
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
1319
// sdpa_params
1420
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
1521
struct sdpa_params : public base_params {
1622
sdpa_params() : base_params(KernelType::SDPA) {}
17-
DimTensor<uint32_t> block_shape;
18-
DimTensor<uint32_t> crops_begin;
19-
DimTensor<uint32_t> crops_end;
20-
21-
base_params::ArgType block_type = base_params::ArgType::Input;
22-
base_params::ArgType begin_type = base_params::ArgType::Input;
23-
base_params::ArgType end_type = base_params::ArgType::Input;
24-
25-
size_t block_dims = 0;
26-
size_t begin_dims = 0;
27-
size_t end_dims = 0;
2823

29-
size_t block_input_index = 0;
30-
size_t begin_input_index = 0;
31-
size_t end_input_index = 0;
24+
sdpa_configuration conf;
3225
};
3326

3427
struct sdpa_fuse_params : fuse_params {
@@ -48,7 +41,6 @@ class SDPAKernelBase : public KernelBaseOpenCL {
4841
protected:
4942
bool Validate(const Params&) const override;
5043
virtual JitConstants GetJitConstants(const sdpa_params& params) const;
51-
virtual CommonDispatchData SetDefault(const sdpa_params& params) const;
5244
KernelsData GetCommonKernelsData(const Params& params) const;
5345
};
5446
} // namespace kernel_selector

0 commit comments

Comments
 (0)