Commit 3ed1868

WIP: SDPA kernel
1 parent 15e3064 · commit 3ed1868

18 files changed: +697 −133 lines changed

src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp (+3)

@@ -4,6 +4,7 @@
 
 #pragma once
 #include "primitive.hpp"
+#include "intel_gpu/graph/program.hpp"
 
 #include <vector>
 
@@ -32,5 +33,7 @@ struct paged_attention : public primitive_base<paged_attention> {
     void load(BinaryInputBuffer& ib) override {
         primitive_base<paged_attention>::load(ib);
     }
+
+    std::shared_ptr<cldnn::program> prefill_stage;
 };
 }  // namespace cldnn

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp (+231 −105)

Large diffs are not rendered by default.

src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h (+2 −3)

@@ -42,11 +42,10 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<p
     typed_primitive_inst(network& network, const paged_attention_node& desc);
     typed_primitive_inst(network& network) : parent(network) {}
 
+    std::shared_ptr<network> prefill_network;
+
 protected:
     void update_shape_info_tensor(const kernel_impl_params& params) override;
-
-private:
-    size_t paged_attention_id = 0;
 };
 
 using paged_attention_inst = typed_primitive_inst<paged_attention>;

src/plugins/intel_gpu/src/graph/include/primitive_inst.h (+1)

@@ -224,6 +224,7 @@ class primitive_inst {
     void reset_output_change() { _output_changed = false; }
 
     bool shape_changed() const { return _shape_changed; }
+    void set_mem_changed(bool mem_changed) { _mem_changed = mem_changed; }
     bool mem_changed() const { return _mem_changed; }
     void reset_shape_change() { _shape_changed = false; }
     void set_shape_change() { _shape_changed = true; }

src/plugins/intel_gpu/src/graph/network.cpp (+1 −1)

@@ -1054,7 +1054,7 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
             auto prog_id = ((get_program() != nullptr) ? get_program()->get_id() : 0);
             auto net_id = get_id();
             GPU_DEBUG_IF(debug_config->is_target_iteration(curr_iter) &&
-                         debug_config->is_layer_for_dumping(layer_name, inst->is_output(), inst->is_input())) {
+                         debug_config->is_layer_for_dumping(layer_name, inst->is_output(), inst->is_input()) && prog_id == 2) {
                 std::string debug_str_for_bin_load = " Command for loading : OV_GPU_LoadDumpRawBinary=\""
                                                         + layer_name + ":";
                 for (size_t i = 0; i < get_primitive(layer_name)->outputs_memory_count(); i++) {

src/plugins/intel_gpu/src/graph/paged_attention.cpp (+5 −1)

@@ -98,5 +98,9 @@ void paged_attention_inst::update_shape_info_tensor(const kernel_impl_params& pa
 }
 
 paged_attention_inst::typed_primitive_inst(network& network, const paged_attention_node& node)
-    : parent(network, node) {}
+    : parent(network, node)
+    , prefill_network(network::allocate_network(network.get_stream_ptr(),
+                                                node.get_primitive()->prefill_stage,
+                                                false,
+                                                network.is_primary_stream())) { }
 }  // namespace cldnn
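Note on the constructor change above: the primitive now carries a separately built prefill_stage program, and every paged_attention instance materializes it as a nested network on the parent's stream. A minimal host-side C++ sketch of this ownership pattern; the struct names below are illustrative stand-ins, not the actual cldnn types, and the real allocate_network factory also compiles kernels and allocates memory:

#include <memory>

// Hypothetical stand-ins for cldnn::program / cldnn::stream / cldnn::network.
struct program {};
struct stream {};

struct network {
    std::shared_ptr<stream> strm;
    std::shared_ptr<program> prog;
    bool primary;

    // Mirrors the shape of network::allocate_network(stream, program,
    // is_internal, is_primary_stream) as used in the constructor above.
    static std::shared_ptr<network> allocate_network(std::shared_ptr<stream> s,
                                                     std::shared_ptr<program> p,
                                                     bool /*is_internal*/,
                                                     bool is_primary_stream) {
        return std::make_shared<network>(network{s, p, is_primary_stream});
    }
};

// The primitive owns the separately built prefill-stage program...
struct paged_attention_prim {
    std::shared_ptr<program> prefill_stage;
};

// ...and each instance turns it into a nested network sharing the parent
// network's stream, so the prefill stage can be executed as a sub-network
// while decode runs through the fused SDPA kernel.
struct paged_attention_inst {
    std::shared_ptr<network> prefill_network;

    paged_attention_inst(network& parent, const paged_attention_prim& prim)
        : prefill_network(network::allocate_network(parent.strm,
                                                    prim.prefill_stage,
                                                    false,
                                                    parent.primary)) {}
};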

src/plugins/intel_gpu/src/graph/primitive_inst.cpp (+2)

@@ -230,6 +230,7 @@ void primitive_inst::check_memory_to_set(const memory& mem, const layout& layout
 }
 
 event::ptr primitive_inst::set_output_memory(memory::ptr mem_new, bool check, size_t idx) {
+    GPU_DEBUG_TRACE_DETAIL << "set_output memory for " << id() << ": " << mem_new << "\n";
     auto& eng = _network.get_engine();
     // skip all the buzz if no action actually required
     event::ptr ev = nullptr;
@@ -245,6 +246,7 @@ event::ptr primitive_inst::set_output_memory(memory::ptr mem_new, bool check, si
     if (is_constant()) {
         ev = mem_new->copy_from(_network.get_stream(), *_outputs[idx], false);
     } else {
+        GPU_DEBUG_TRACE_DETAIL << "change output memory: " << mem_new << "\n";
         ev = get_network().get_stream().create_user_event(true);
         _outputs[idx] = mem_new;
     }

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/kv_cache_update_ref.cl → src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl (renamed, +7 −7)

@@ -4,7 +4,7 @@
 
 #include "include/batch_headers/common.cl"
 
-KERNEL(kv_cache_update)(
+KERNEL(pa_kv_cache_update)(
     OPTIONAL_SHAPE_INFO_ARG
     __global const INPUT0_TYPE* key_data,
     __global const INPUT1_TYPE* value_data,
@@ -15,21 +15,21 @@ KERNEL(pa_kv_cache_update)(
 {
     const uint batch_idx = (uint)get_global_id(0);
     const uint seq_idx = (uint)get_global_id(1);
-    const uint hidden_idx = (uint)get_global_id(2); /* head_size */
+    const uint head_elem_idx = (uint)get_global_id(2);
 
-    const uint in_offset = batch_idx * INPUT0_BATCH_PITCH + seq_idx * INPUT0_FEATURE_PITCH + hidden_idx;
+    const uint in_offset = batch_idx * INPUT0_BATCH_PITCH + seq_idx * INPUT0_FEATURE_PITCH + head_elem_idx;
     const uint slot_offset = batch_idx * INPUT0_FEATURE_NUM + seq_idx;
 
     const INPUT2_TYPE slot_idx = slot_mapping[slot_offset];
-    if (hidden_idx >= INPUT0_FEATURE_PITCH || slot_idx == -1)
+    if (head_elem_idx >= INPUT0_FEATURE_PITCH || slot_idx == -1)
         return;
 
     const uint block_index = slot_idx / KV_CACHE_BLOCK_SIZE;
     const uint block_offset = slot_idx % KV_CACHE_BLOCK_SIZE;
 
 #ifdef VALUE_CACHE_UPDATE
     const uint out_offset = block_elem_num * block_index +
-                            hidden_idx * KV_CACHE_BLOCK_SIZE +
+                            head_elem_idx * KV_CACHE_BLOCK_SIZE +
                             block_offset;
 
     // if (INPUT0_FEATURE_NUM == 18 && INPUT0_BATCH_NUM == 2) {
@@ -39,8 +39,8 @@ KERNEL(pa_kv_cache_update)(
     value_cache_data[out_offset] = value_data[in_offset];
 #else
 #define HEAD_SIZE_BLOCKING 4
-    const uint head_size_outer_block = hidden_idx / HEAD_SIZE_BLOCKING;
-    const uint head_size_inner_block = hidden_idx % HEAD_SIZE_BLOCKING;
+    const uint head_size_outer_block = head_elem_idx / HEAD_SIZE_BLOCKING;
+    const uint head_size_inner_block = head_elem_idx % HEAD_SIZE_BLOCKING;
 
     const uint out_offset = block_elem_num * block_index +
                             block_offset * HEAD_SIZE_BLOCKING +
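For reference, the offsets in this kernel follow a paged cache layout where each block stores KV_CACHE_BLOCK_SIZE tokens of one head: the value cache keeps tokens adjacent per head element, while the key cache groups head elements into chunks of HEAD_SIZE_BLOCKING per token. A scalar C++ sketch of the same index arithmetic, under that assumed layout (an illustration, not plugin code; the constants are example values):

#include <cstddef>

constexpr std::size_t KV_CACHE_BLOCK_SIZE = 16; // tokens per cache block
constexpr std::size_t HEAD_SIZE = 64;           // elements per head (example)
constexpr std::size_t X_BLOCKING = 4;           // HEAD_SIZE_BLOCKING in the kernel

// One cache block holds BLOCK_SIZE tokens of one head.
constexpr std::size_t block_elem_num = KV_CACHE_BLOCK_SIZE * HEAD_SIZE;

// slot_idx comes from slot_mapping; head_elem_idx < HEAD_SIZE.
std::size_t value_cache_offset(long slot_idx, std::size_t head_elem_idx) {
    const std::size_t block_index  = slot_idx / KV_CACHE_BLOCK_SIZE;
    const std::size_t block_offset = slot_idx % KV_CACHE_BLOCK_SIZE;
    // Value layout: [block][head_elem][token] -> consecutive tokens adjacent.
    return block_elem_num * block_index +
           head_elem_idx * KV_CACHE_BLOCK_SIZE +
           block_offset;
}

std::size_t key_cache_offset(long slot_idx, std::size_t head_elem_idx) {
    const std::size_t block_index  = slot_idx / KV_CACHE_BLOCK_SIZE;
    const std::size_t block_offset = slot_idx % KV_CACHE_BLOCK_SIZE;
    // Key layout: [block][head_elem / X][token][head_elem % X], so each token
    // carries X_BLOCKING consecutive head elements per inner chunk.
    const std::size_t outer = head_elem_idx / X_BLOCKING; // head_size_outer_block
    const std::size_t inner = head_elem_idx % X_BLOCKING; // head_size_inner_block
    return block_elem_num * block_index +
           outer * KV_CACHE_BLOCK_SIZE * X_BLOCKING +
           block_offset * X_BLOCKING +
           inner;
}

This token-major, X-chunked key layout is what lets the SDPA kernel below read X_SIZE-contiguous key elements per token when computing QK dot products.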
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_ref.cl (new file, +178)

@@ -0,0 +1,178 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "include/batch_headers/common.cl"
+#include "include/batch_headers/sub_group_block_read.cl"
+#include "include/batch_headers/sub_group_block_write.cl"
+#include "include/batch_headers/sub_group_shuffle.cl"
+
+// constexpr size_t HEAD_SIZE = 64;
+// constexpr size_t HEADS_NUM = 32;
+// constexpr size_t KV_HEADS_NUM = 4;
+// constexpr size_t BLOCK_SIZE = 16;
+// constexpr size_t X_SIZE = 4;
+
+// constexpr size_t MAX_SEQUENCE_LENGTH = 1024;
+
+#define SUB_GROUP_SIZE 16
+
+// The portion of HEAD_SIZE each WI processes
+#define HEAD_ITEMS_PER_WI (HEAD_SIZE / SUB_GROUP_SIZE)
+
+// How many QK outputs each subgroup calculates per cycle
+#define QK_PER_SG 4
+
+#define KV_CACHE_BLOCK_STRIDE (HEAD_SIZE * HEADS_NUM * BLOCK_SIZE)
+
+#define QUERY_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)
+
+#define SUBGROUPS_PER_WG (HEAD_SIZE / SUB_GROUP_SIZE)
+
+REQD_SUB_GROUP_SIZE(SUB_GROUP_SIZE)
+__attribute__((reqd_work_group_size(1, 1, SUB_GROUP_SIZE)))
+KERNEL(pa_sdpa_ref)(
+    OPTIONAL_SHAPE_INFO_ARG
+    __global const INPUT0_TYPE* query,
+    __global const INPUT1_TYPE* key_cache,
+    __global const INPUT2_TYPE* value_cache,
+    __global const INPUT3_TYPE* max_context_len,
+    __global const INPUT4_TYPE* context_lens,
+    __global const INPUT5_TYPE* block_tables,
+    __global OUTPUT_TYPE* output)
+{
+    const uint seq_idx = get_global_id(0);
+    const uint head_num_idx = get_global_id(1);
+    const uint head_idx = get_global_id(2);
+    const uint sglid = get_sub_group_local_id();
+    const uint sgid = get_sub_group_id();
+
+    const uint batch_idx = seq_idx / INPUT0_FEATURE_NUM;
+    const uint token_idx = seq_idx % INPUT0_FEATURE_NUM;
+
+    const uint context_len = context_lens[batch_idx];
+
+    const uint blocks_num = INPUT5_FEATURE_NUM;
+
+    // Token ranges per subgroup, first block:
+    // sgid0: 0..3
+    // sgid1: 4..7
+    // sgid2: 8..11
+    // sgid3: 12..15
+    // second block:
+    // sgid0: 16..19
+    // sgid1: 20..23
+    // sgid2: 24..27
+    // sgid3: 28..31
+
+    // TODO: Need to make blocks division more flexible. The current approach assumes
+    // 4 SGs per WG, where each SG processes 4 QK outputs, i.e. 16 in total per WG
+
+    __local OUTPUT_TYPE qk_vals[SHARED_MEM_SIZE];
+
+    OUTPUT_TYPE qk_max = OUTPUT_VAL_MIN;
+
+    for (uint block = 0; block < blocks_num; block++) {
+        const uint block_idx = batch_idx * blocks_num + block;
+        const uint block_offset = block_tables[block_idx] * KV_CACHE_BLOCK_STRIDE;
+
+        OUTPUT_TYPE qk[QK_PER_SG] = {0};
+
+        for (uint hs = 0; hs < HEAD_ITEMS_PER_WI; hs++) {
+            const uint query_idx = seq_idx * HEAD_SIZE * HEADS_NUM + hs * SUB_GROUP_SIZE;
+
+            // TODO: can be preloaded outside the HEAD_ITEMS_PER_WI loop - need to check perf
+            INPUT0_TYPE q = QUERY_BLOCK_READ(query, query_idx);
+            for (uint qk_idx = 0; qk_idx < QK_PER_SG; qk_idx++) {
+                uint current_token = block * BLOCK_SIZE + sgid * QK_PER_SG + qk_idx;
+                if (current_token >= context_len)
+                    continue;
+
+                const uint key_idx = block_offset +
+                                     (X_SIZE * QK_PER_SG) * sgid +
+                                     (HEAD_ITEMS_PER_WI * BLOCK_SIZE * X_SIZE) * hs +
+                                     (sglid / X_SIZE) * X_SIZE * BLOCK_SIZE +
+                                     (sglid % X_SIZE) + qk_idx * X_SIZE;
+                // TODO1: try block loading and shuffling
+                // TODO2: try to load k*4 times and then calculate
+                // TODO3: try bigger X block
+                INPUT1_TYPE k = key_cache[key_idx];
+
+                qk[qk_idx] = mad(q, k, qk[qk_idx]);
+            }
+        }
+
+        // Summarize qk calculation across all WIs
+        for (uint qk_idx = 0; qk_idx < QK_PER_SG; qk_idx++) {
+            qk[qk_idx] = sub_group_reduce_add(qk[qk_idx]);
+            qk_max = OUTPUT_MAX_FUNC(qk_max, qk[qk_idx]);
+        }
+
+        // Save QK results to local memory: each SG owns QK_PER_SG consecutive tokens
+        if (sglid < QK_PER_SG) {
+            const uint qk_local_idx = block * BLOCK_SIZE + sgid * QK_PER_SG + sglid;
+            qk_vals[qk_local_idx] = qk[sglid];
+        }
+    }
+
+    /* WARNING: NEED TO ADD BIAS BEFORE SOFTMAX */
+
+    // Apply SoftMax operation
+    __local OUTPUT_TYPE qk_max_vals[SUBGROUPS_PER_WG];
+    __local OUTPUT_TYPE qk_sum_vals[SUBGROUPS_PER_WG];
+    {
+        if (sglid == 0)
+            qk_max_vals[sgid] = qk_max;
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        qk_max = OUTPUT_VAL_MIN;
+        if (sglid < SUBGROUPS_PER_WG)
+            qk_max = qk_max_vals[sglid];
+
+        // Final max value after reduction across all SGs and WIs
+        qk_max = sub_group_reduce_max(qk_max);
+
+        OUTPUT_TYPE exp_sum = OUTPUT_VAL_ZERO;
+        for (uint qk_idx = 0; qk_idx < CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
+            const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
+            if (data_idx < context_len) {
+                OUTPUT_TYPE val = native_exp(qk_vals[data_idx] - qk_max);
+                exp_sum += val;
+                qk_vals[data_idx] = val;
+            }
+        }
+
+        exp_sum = sub_group_reduce_add(exp_sum);
+
+        if (sglid == 0)
+            qk_sum_vals[sgid] = exp_sum;
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        // Reset, then gather the per-SG partial sums for the final reduction
+        exp_sum = OUTPUT_VAL_ZERO;
+        if (sglid < SUBGROUPS_PER_WG)
+            exp_sum = qk_sum_vals[sglid];
+
+        // Final sum of all values
+        exp_sum = sub_group_reduce_add(exp_sum);
+
+        const OUTPUT_TYPE inv_sum = OUTPUT_VAL_ONE / exp_sum;
+
+        for (uint qk_idx = 0; qk_idx < CEIL_DIV(context_len, SUBGROUPS_PER_WG * SUB_GROUP_SIZE); qk_idx++) {
+            const uint data_idx = qk_idx * (SUBGROUPS_PER_WG * SUB_GROUP_SIZE) + sgid * SUB_GROUP_SIZE + sglid;
+            if (data_idx < context_len) {
+                OUTPUT_TYPE val = qk_vals[data_idx] * inv_sum;
+                qk_vals[data_idx] = val;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    // WIP: temporary store of softmax weights; the softmax * V stage is not implemented yet
+    output[seq_idx + sglid] = qk_vals[sglid % context_len];
+}
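As a plain-math cross-check for this WIP kernel, here is a scalar C++ reference of single-query attention over a gathered context: QK dot products, numerically stable softmax, and the softmax×V accumulation the kernel does not implement yet. All names are illustrative; keys/values are assumed already gathered via the block tables, and the scale the kernel still marks as missing ("NEED TO ADD BIAS BEFORE SOFTMAX") is noted where it would go:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference for one (sequence, head) pair; the kernel parallelizes the same
// math across subgroups and tokens.
std::vector<float> single_query_paged_attention(
        const std::vector<float>& query,               // [head_size]
        const std::vector<std::vector<float>>& keys,   // [context_len][head_size]
        const std::vector<std::vector<float>>& values, // [context_len][head_size]
        std::size_t context_len) {
    const std::size_t head_size = query.size();

    // 1) QK dot products -- the kernel's per-token mad() accumulation.
    std::vector<float> qk(context_len, 0.0f);
    for (std::size_t t = 0; t < context_len; ++t)
        for (std::size_t h = 0; h < head_size; ++h)
            qk[t] += query[h] * keys[t][h];
    // Scale/bias (e.g. 1/sqrt(head_size)) would be applied here; the kernel
    // still carries a WARNING that this step is missing.

    // 2) Numerically stable softmax: subtract max, exponentiate, normalize.
    //    This mirrors the kernel's two-level max/sum reductions.
    const float qk_max = *std::max_element(qk.begin(), qk.end());
    float exp_sum = 0.0f;
    for (std::size_t t = 0; t < context_len; ++t) {
        qk[t] = std::exp(qk[t] - qk_max);
        exp_sum += qk[t];
    }
    for (std::size_t t = 0; t < context_len; ++t)
        qk[t] /= exp_sum;

    // 3) Weighted sum over values -- the stage the WIP kernel has not written yet.
    std::vector<float> out(head_size, 0.0f);
    for (std::size_t t = 0; t < context_len; ++t)
        for (std::size_t h = 0; h < head_size; ++h)
            out[h] += qk[t] * values[t][h];
    return out;
}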

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/softmax_gpu_bf.cl (+4)

@@ -44,6 +44,10 @@ KERNEL (softmax_gpu_continuous_bfyx)(
     const uint power = CALC_POWER(workers_per_data_set);
     const uint items_num = data_set_size>>power;
     const uint leftovers = data_set_size-(items_num<<power);
+    if (data_set_idx == 0 && in_data_set_idx == 0) {
+        printf("Power=%d, items_num=%d, leftovers=%d, data_set_size=%d, sub_group_size=%d\n", power, items_num,
+               leftovers, data_set_size, get_sub_group_size());
+    }
 #endif
 
     const uint data_set_offset = data_set_idx * data_set_size;
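The printf above dumps how a data set is split across its workers: items_num full elements per worker plus leftovers handled separately. A small sketch of the arithmetic, assuming CALC_POWER computes floor(log2(workers_per_data_set)) (an assumption; the macro's exact definition is not shown here):

#include <cstddef>
#include <cstdio>

// Assumed semantics of CALC_POWER: floor(log2(workers_per_data_set)).
static unsigned calc_power(std::size_t workers_per_data_set) {
    unsigned power = 0;
    while ((std::size_t{1} << (power + 1)) <= workers_per_data_set)
        ++power;
    return power;
}

int main() {
    const std::size_t data_set_size = 1000, workers_per_data_set = 16;
    const unsigned power = calc_power(workers_per_data_set);
    const std::size_t items_num = data_set_size >> power;               // elements per worker
    const std::size_t leftovers = data_set_size - (items_num << power); // remainder elements
    std::printf("power=%u items_num=%zu leftovers=%zu\n", power, items_num, leftovers);
    // Prints: power=4 items_num=62 leftovers=8  (62*16 + 8 = 1000)
    return 0;
}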

src/plugins/intel_gpu/src/kernel_selector/kernels/paged_attention/kv_cache_update_kernel_ref.cpp (+8 −8)

@@ -13,7 +13,7 @@ static constexpr size_t kv_cache_block_size = 16;
 
 void KVCacheUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const {
     kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
-        const auto& prim_params = dynamic_cast<const kv_cache_update_update_params&>(params);
+        const auto& prim_params = dynamic_cast<const kv_cache_update_params&>(params);
         auto dispatchData = SetDefault(prim_params);
         OPENVINO_ASSERT(kd.kernels.size() == 2, "[GPU] Invalid kernels size for update dispatch data func");
         kd.kernels[0].params.workGroups.global = dispatchData.gws;
@@ -40,11 +40,11 @@ KernelsData KVCacheUpdateKernelRef::GetKernelsData(const Params& params) const {
         return {};
     }
 
-    KernelData kd = KernelData::Default<kv_cache_update_update_params>(params, 2);
+    KernelData kd = KernelData::Default<kv_cache_update_params>(params, 2);
     kd.needs_sub_kernels_sync = false;
     GetUpdateDispatchDataFunc(kd);
 
-    const auto& kernel_params = static_cast<const kv_cache_update_update_params&>(params);
+    const auto& kernel_params = static_cast<const kv_cache_update_params&>(params);
     for (size_t i = 0; i < 2; i++) {
         const auto kernel_stage = i == 0 ? KernelMode::value_cache_update : KernelMode::key_cache_update;
         const auto dispatch_data = SetDefault(kernel_params);
@@ -101,7 +101,7 @@ bool KVCacheUpdateKernelRef::Validate(const Params& params) const {
         return false;
     }
 
-    const auto& kernel_params = dynamic_cast<const kv_cache_update_update_params&>(params);
+    const auto& kernel_params = dynamic_cast<const kv_cache_update_params&>(params);
     if (kernel_params.inputs.size() != 3)
         return false;
 
@@ -114,7 +114,7 @@ bool KVCacheUpdateKernelRef::Validate(const Params& params) const {
     return true;
 }
 
-JitConstants KVCacheUpdateKernelRef::GetJitConstants(const kv_cache_update_update_params& kernel_params, KernelMode mode) const {
+JitConstants KVCacheUpdateKernelRef::GetJitConstants(const kv_cache_update_params& kernel_params, KernelMode mode) const {
     JitConstants jit = MakeBaseParamsJitConstants(kernel_params);
 
     if (mode == KernelMode::key_cache_update)
@@ -127,7 +127,7 @@ JitConstants KVCacheUpdateKernelRef::GetJitConstants(const kv_cache_update_updat
     return jit;
 }
 
-CommonDispatchData KVCacheUpdateKernelRef::SetDefault(const kv_cache_update_update_params& kernel_params) {
+CommonDispatchData KVCacheUpdateKernelRef::SetDefault(const kv_cache_update_params& kernel_params) {
    CommonDispatchData dispatch_data;
 
     const auto& input = kernel_params.inputs[0];
@@ -139,8 +139,8 @@ CommonDispatchData KVCacheUpdateKernelRef::SetDefault(const kv_cache_update_upda
     const size_t batch_size = input.Batch().v;
     const size_t seq_len = input.Feature().v;
     const size_t tokens_num = batch_size * seq_len;
-    const size_t hidden_size = input.LogicalSize() / (tokens_num);
-    dispatch_data.gws = {batch_size, seq_len, Align(hidden_size, 16)};
+    const size_t head_size = input.LogicalSize() / (tokens_num);
+    dispatch_data.gws = {batch_size, seq_len, Align(head_size, 16)};
     dispatch_data.lws = {1, 1, 16};
 }
 
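The SetDefault() change above recovers the per-token head size from the flattened element count and pads the innermost axis to the 16-wide local size; the padding lanes are discarded by the kernel's head_elem_idx bound check. A small sketch of that computation, with align_to standing in for the kernel selector's Align helper (illustrative, not the real CommonDispatchData plumbing):

#include <array>
#include <cstddef>

// Round v up to the next multiple of a (the role Align() plays above).
constexpr std::size_t align_to(std::size_t v, std::size_t a) {
    return (v + a - 1) / a * a;
}

// Mirrors SetDefault(): gws = {batch, seq_len, Align(head_size, 16)},
// paired with lws = {1, 1, 16}; head_size comes from the flattened size.
std::array<std::size_t, 3> kv_cache_update_gws(std::size_t batch_size,
                                               std::size_t seq_len,
                                               std::size_t logical_size) {
    const std::size_t tokens_num = batch_size * seq_len;
    const std::size_t head_size = logical_size / tokens_num;
    return {batch_size, seq_len, align_to(head_size, 16)};
}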
src/plugins/intel_gpu/src/kernel_selector/kernels/paged_attention/kv_cache_update_kernel_ref.hpp (+5 −5)

@@ -13,20 +13,20 @@ enum class KernelMode {
     value_cache_update
 };
 
-struct kv_cache_update_update_params : base_params {
-    kv_cache_update_update_params() : base_params(KernelType::PA_KV_CACHE_UPDATE) {}
+struct kv_cache_update_params : base_params {
+    kv_cache_update_params() : base_params(KernelType::PA_KV_CACHE_UPDATE) {}
 };
 
 class KVCacheUpdateKernelRef : public KernelBaseOpenCL {
 public:
-    KVCacheUpdateKernelRef() : KernelBaseOpenCL{"kv_cache_update_ref"} {}
+    KVCacheUpdateKernelRef() : KernelBaseOpenCL{"pa_kv_cache_update_ref"} {}
     KernelsData GetKernelsData(const Params& params) const override;
     ParamsKey GetSupportedKey() const override;
 
 protected:
     bool Validate(const Params& params) const override;
-    JitConstants GetJitConstants(const kv_cache_update_update_params& kernel_params, KernelMode mode) const;
-    static CommonDispatchData SetDefault(const kv_cache_update_update_params& kernel_params);
+    JitConstants GetJitConstants(const kv_cache_update_params& kernel_params, KernelMode mode) const;
+    static CommonDispatchData SetDefault(const kv_cache_update_params& kernel_params);
     void GetUpdateDispatchDataFunc(KernelData& kd) const override;
 };