|
| 1 | +// Copyright (C) 2023-2024 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include "intel_gpu/plugin/multi_tensor_variable_state.hpp" |
| 6 | +#include "intel_gpu/plugin/variable_state.hpp" |
| 7 | +#include "intel_gpu/runtime/debug_configuration.hpp" |
| 8 | +#include "intel_gpu/runtime/memory.hpp" |
| 9 | +#include "multi_stage_primitive.hpp" |
| 10 | + |
| 11 | +#include "paged_attention_inst.h" |
| 12 | +#include "paged_attention/paged_attention_kernel_selector.hpp" |
| 13 | +#include "paged_attention/kv_cache_update_kernel_ref.hpp" |
| 14 | +#include "paged_attention/sdpa_kernel_ref.hpp" |
| 15 | + |
| 16 | +namespace cldnn { |
| 17 | +namespace ocl { |
| 18 | + |
| 19 | +struct paged_attention_impl : multi_stage_primitive<paged_attention> { |
| 20 | + using parent = multi_stage_primitive<paged_attention>; |
| 21 | + using parent::parent; |
| 22 | + using kv_cache_update_kernel_selector_t = kernel_selector::kv_cache_update_kernel_selector; |
| 23 | + using kv_cache_update_kernel_params_t = kernel_selector::kv_cache_update_params; |
| 24 | + |
| 25 | + using sdpa_kernel_selector_t = kernel_selector::sdpa_kernel_selector; |
| 26 | + using sdpa_kernel_params_t = kernel_selector::sdpa_params; |
| 27 | + |
| 28 | + DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::paged_attention_impl) |
| 29 | + |
| 30 | + std::unique_ptr<primitive_impl> clone() const override { |
| 31 | + return make_unique<paged_attention_impl>(*this); |
| 32 | + } |
| 33 | + |
| 34 | + paged_attention_impl() = default; |
| 35 | + |
| 36 | + paged_attention_impl(const std::vector<kernel_selector::kernel_data>& kd) : parent(kd) { |
| 37 | + this->can_reuse_memory = true; |
| 38 | + } |
| 39 | + |
| 40 | + void set_arguments_impl(paged_attention_inst& instance) override {} |
| 41 | + kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage) const override { return kernel_arguments_data(); } |
| 42 | + |
| 43 | + enum Stage { |
| 44 | + KV_CACHE_UPDATE, |
| 45 | + SDPA |
| 46 | + }; |
| 47 | + |
| 48 | + void load(BinaryInputBuffer& ib) override { |
| 49 | + parent::load(ib); |
| 50 | + if (is_dynamic()) { |
| 51 | + auto& kernel_selector = kv_cache_update_kernel_selector_t::Instance(); |
| 52 | + auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_UPDATE].kernelName); |
| 53 | + kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::KV_CACHE_UPDATE]); |
| 54 | + |
| 55 | + auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance(); |
| 56 | + auto bt_kernel_impl = sdpa_kernel_selector.GetImplementation(_kernels_data[Stage::SDPA].kernelName); |
| 57 | + bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::SDPA]); |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage, size_t kernel_idx) const { |
| 62 | + kernel_arguments_data args; |
| 63 | + if (stage == Stage::KV_CACHE_UPDATE || (stage == Stage::SDPA && kernel_idx == 0)) |
| 64 | + args.shape_info = instance.shape_info_memory_ptr(); |
| 65 | + |
| 66 | + if (stage == Stage::KV_CACHE_UPDATE) { |
| 67 | + args.inputs = { instance.input_memory_ptr(1), /* key */ |
| 68 | + instance.input_memory_ptr(2), /* value */ |
| 69 | + instance.input_memory_ptr(6) /* slot_mapping */}; |
| 70 | + args.outputs = { instance.input_memory_ptr(3), /* key_cache */ |
| 71 | + instance.input_memory_ptr(4) /* value_cache */ }; |
| 72 | + } else if (stage == Stage::SDPA) { |
| 73 | + if (kernel_idx == 0) { |
| 74 | + args.inputs = { instance.input_memory_ptr(0), /* query */ |
| 75 | + instance.input_memory_ptr(3), /* key_cache */ |
| 76 | + instance.input_memory_ptr(4), /* value_cache */ |
| 77 | + instance.input_memory_ptr(7), /* max_context_len */ |
| 78 | + instance.input_memory_ptr(8), /* context_lens */ |
| 79 | + instance.input_memory_ptr(9), /* block_tables */ |
| 80 | + instance.input_memory_ptr(10) /* scale */ }; |
| 81 | + } else { |
| 82 | + args.inputs = { instance.input_memory_ptr(8), /* context_lens */ }; |
| 83 | + } |
| 84 | + args.outputs = { instance.output_memory_ptr(0) }; |
| 85 | + } |
| 86 | + |
| 87 | + return args; |
| 88 | + } |
| 89 | + |
| 90 | + void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) { |
| 91 | + stream& stream = instance.get_network().get_stream(); |
| 92 | + std::vector<event::ptr> tmp_events(events); |
| 93 | + size_t kernel_offset = 0; |
| 94 | + for (size_t s = 0; s < stage; s++) { |
| 95 | + kernel_offset += _kernels_data[s].kernels.size(); |
| 96 | + } |
| 97 | + for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) { |
| 98 | + auto time0 = std::chrono::high_resolution_clock::now(); |
| 99 | + if (_kernels_data[stage].kernels[kd_idx].skip_execution) |
| 100 | + continue; |
| 101 | + |
| 102 | + size_t idx_final = kernel_offset + kd_idx; |
| 103 | + // If any user of the prim's users is CPU implementation or network's output, set prim as a output event (event won't be nullptr) |
| 104 | + bool needs_completion_event = instance.needs_completion_event(); |
| 105 | + |
| 106 | + auto& params = _kernels_data[stage].kernels[kd_idx].params; |
| 107 | + |
| 108 | + auto args = get_arguments(instance, stage, kd_idx); |
| 109 | + args.scalars = ¶ms.scalars; |
| 110 | + |
| 111 | + for (const auto& m : instance.get_intermediates_memories()) { |
| 112 | + args.intermediates.push_back(m); |
| 113 | + } |
| 114 | + |
| 115 | + auto time1 = std::chrono::high_resolution_clock::now(); |
| 116 | + stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args); |
| 117 | + auto time2 = std::chrono::high_resolution_clock::now(); |
| 118 | + |
| 119 | + const auto& gws = params.workGroups.global; |
| 120 | + const auto& lws = params.workGroups.local; |
| 121 | + |
| 122 | + GPU_DEBUG_TRACE_DETAIL << "Enqueue stage " << stage << " kernel " << idx_final << ": gws=[" << gws[0] << ", " << gws[1] << ", " << gws[2] << "] " |
| 123 | + << "lws=[" << lws[0] << ", " << lws[1] << ", " << lws[2] << "]" |
| 124 | + << (needs_completion_event ? " has_completion_event=true" : "") << std::endl; |
| 125 | + |
| 126 | + auto ev = stream.enqueue_kernel(*_kernels[idx_final], params, args, tmp_events, needs_completion_event); |
| 127 | + auto time3 = std::chrono::high_resolution_clock::now(); |
| 128 | + if (_kernels_data[stage].needs_sub_kernels_sync) { |
| 129 | + tmp_events = {ev}; |
| 130 | + } |
| 131 | + |
| 132 | + auto time_res0 = std::chrono::duration_cast<std::chrono::microseconds>(time1 - time0).count(); |
| 133 | + auto time_res1 = std::chrono::duration_cast<std::chrono::microseconds>(time2 - time1).count(); |
| 134 | + auto time_res2 = std::chrono::duration_cast<std::chrono::microseconds>(time3 - time2).count(); |
| 135 | + GPU_DEBUG_TRACE_DETAIL << "Time execute_stage inside = " << time_res0 << " " << time_res1 << " " << time_res2 << "\n"; |
| 136 | + |
| 137 | + all_events.push_back(ev); |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override { |
| 142 | + std::vector<event::ptr> res_events; |
| 143 | + execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE); |
| 144 | + |
| 145 | + std::vector<event::ptr> dep_events(res_events.begin(), res_events.end()); |
| 146 | + execute_stage(dep_events, instance, res_events, Stage::SDPA); |
| 147 | + |
| 148 | + return aggregate_events(res_events, instance.get_network().get_stream(), res_events.size() > 1); |
| 149 | + } |
| 150 | + |
| 151 | + static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) { |
| 152 | + kernel_selector::sdpa_configuration config; |
| 153 | + |
| 154 | + const auto query_layout = impl_param.get_input_layout(0); |
| 155 | + const auto key_cache_layout = impl_param.get_input_layout(3); |
| 156 | + const auto value_cache_layout = impl_param.get_input_layout(4); |
| 157 | + |
| 158 | + const auto desc = impl_param.typed_desc<paged_attention>(); |
| 159 | + config.head_size = desc->head_size; |
| 160 | + config.heads_num = desc->heads_num; |
| 161 | + config.kv_heads_num = desc->kv_heads_num; |
| 162 | + config.block_size = desc->block_size; |
| 163 | + config.x_block_size = desc->x_block_size; |
| 164 | + config.max_context_len = 1; |
| 165 | + |
| 166 | + const size_t simd_size = 16; |
| 167 | + OPENVINO_ASSERT(config.head_size % simd_size == 0, "[GPU] Head size is expected to be divisible by 16"); |
| 168 | + |
| 169 | + return config; |
| 170 | + } |
| 171 | + |
| 172 | + static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) { |
| 173 | + kv_cache_update_kernel_params_t params; |
| 174 | + set_params(impl_param, params); |
| 175 | + |
| 176 | + auto query = impl_param.get_input_layout(0); |
| 177 | + auto key = impl_param.get_input_layout(1); |
| 178 | + auto value = impl_param.get_input_layout(2); |
| 179 | + auto key_cache = impl_param.get_input_layout(3); |
| 180 | + auto value_cache = impl_param.get_input_layout(4); |
| 181 | + auto slot_mapping = impl_param.get_input_layout(6); |
| 182 | + |
| 183 | + params.is_shape_agnostic = is_dynamic; |
| 184 | + params.stage_id = 0; |
| 185 | + params.inputs.resize(3); |
| 186 | + params.outputs.resize(2); |
| 187 | + params.inputs[0] = convert_data_tensor(key); |
| 188 | + params.inputs[1] = convert_data_tensor(value); |
| 189 | + params.inputs[2] = convert_data_tensor(slot_mapping); |
| 190 | + params.outputs[0] = convert_data_tensor(key_cache); |
| 191 | + params.outputs[1] = convert_data_tensor(value_cache); |
| 192 | + params.layerID = impl_param.desc->id; |
| 193 | + |
| 194 | + params.configuration = get_sdpa_configuration(impl_param); |
| 195 | + |
| 196 | + const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; |
| 197 | + std::map<size_t, size_t> in_tensor_to_offset_map = { |
| 198 | + {0, in_offsets_map.at(1)}, |
| 199 | + {1, in_offsets_map.at(2)}, |
| 200 | + {2, in_offsets_map.at(6)}, |
| 201 | + }; |
| 202 | + std::map<size_t, size_t> out_tensor_to_offset_map = { |
| 203 | + {0, in_offsets_map.at(3)}, |
| 204 | + {1, in_offsets_map.at(4)}, |
| 205 | + }; |
| 206 | + |
| 207 | + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); |
| 208 | + |
| 209 | + return params; |
| 210 | + } |
| 211 | + |
| 212 | + static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) { |
| 213 | + auto params = get_default_params<kernel_selector::sdpa_params>(impl_param, is_dynamic); |
| 214 | + |
| 215 | + const auto inputs_count = 7; |
| 216 | + const auto query_layout = impl_param.get_input_layout(0); |
| 217 | + const auto key_cache_layout = impl_param.get_input_layout(3); |
| 218 | + const auto value_cache_layout = impl_param.get_input_layout(4); |
| 219 | + const auto max_context_len_layout = impl_param.get_input_layout(7); |
| 220 | + const auto context_lens_layout = impl_param.get_input_layout(8); |
| 221 | + const auto block_tables_layout = impl_param.get_input_layout(9); |
| 222 | + const auto scale_layout = impl_param.get_input_layout(10); |
| 223 | + |
| 224 | + params.inputs.resize(inputs_count); |
| 225 | + params.inputs[1] = convert_data_tensor(key_cache_layout); |
| 226 | + params.inputs[2] = convert_data_tensor(value_cache_layout); |
| 227 | + params.inputs[3] = convert_data_tensor(max_context_len_layout); |
| 228 | + params.inputs[4] = convert_data_tensor(context_lens_layout); |
| 229 | + params.inputs[5] = convert_data_tensor(block_tables_layout); |
| 230 | + params.inputs[6] = convert_data_tensor(scale_layout); |
| 231 | + |
| 232 | + params.configuration = get_sdpa_configuration(impl_param); |
| 233 | + if (!is_dynamic) { |
| 234 | + auto& constant_mem = impl_param.memory_deps; |
| 235 | + |
| 236 | + const auto max_context_len_mem = constant_mem.at(7); |
| 237 | + mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream()); |
| 238 | + |
| 239 | + const auto is_prompt_stage_mem = constant_mem.at(5); |
| 240 | + mem_lock<uint8_t, mem_lock_type::read> is_prompt_stage_mem_lock(is_prompt_stage_mem, impl_param.get_stream()); |
| 241 | + bool is_prompt_stage = is_prompt_stage_mem_lock[0]; |
| 242 | + |
| 243 | + if (is_prompt_stage) { |
| 244 | + // Use number of slots for KV cache as a maximum context length for the first iteration |
| 245 | + auto slot_mapping = impl_param.get_input_layout(6); |
| 246 | + params.configuration.max_context_len = slot_mapping.get_shape()[1]; |
| 247 | + } else { |
| 248 | + const auto max_context_len_mem = constant_mem.at(7); |
| 249 | + mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len_mem, impl_param.get_stream()); |
| 250 | + params.configuration.max_context_len = max_context_len_mem_lock[0]; |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; |
| 255 | + const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; |
| 256 | + std::map<size_t, size_t> in_tensor_to_offset_map = { |
| 257 | + {0, in_offsets_map.at(0)}, |
| 258 | + {1, in_offsets_map.at(3)}, |
| 259 | + {2, in_offsets_map.at(4)}, |
| 260 | + {3, in_offsets_map.at(7)}, |
| 261 | + {4, in_offsets_map.at(8)}, |
| 262 | + {5, in_offsets_map.at(9)}, |
| 263 | + {6, in_offsets_map.at(10)}, |
| 264 | + }; |
| 265 | + std::map<size_t, size_t> out_tensor_to_offset_map = { |
| 266 | + {0, out_offsets_map.at(0)}, |
| 267 | + }; |
| 268 | + |
| 269 | + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); |
| 270 | + |
| 271 | + return params; |
| 272 | + } |
| 273 | + |
| 274 | + static std::unique_ptr<primitive_impl> create(const typed_program_node<paged_attention>& arg, const kernel_impl_params& impl_param) { |
| 275 | + std::vector<kernel_selector::kernel_data> kernels_data; |
| 276 | + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic()); |
| 277 | + auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance(); |
| 278 | + kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params)); |
| 279 | + |
| 280 | + auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic()); |
| 281 | + auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance(); |
| 282 | + kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params)); |
| 283 | + |
| 284 | + return cldnn::make_unique<paged_attention_impl>(kernels_data); |
| 285 | + } |
| 286 | + |
| 287 | + void update_dispatch_data(const kernel_impl_params& impl_param) override { |
| 288 | + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic()); |
| 289 | + (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]); |
| 290 | + |
| 291 | + auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic()); |
| 292 | + (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]); |
| 293 | + } |
| 294 | +}; |
| 295 | + |
| 296 | +namespace detail { |
| 297 | + |
| 298 | +attach_paged_attention_impl::attach_paged_attention_impl() { |
| 299 | + auto types = { data_types::f16, data_types::f32 }; |
| 300 | + auto formats = { format::bfyx }; |
| 301 | + implementation_map<paged_attention>::add(impl_types::ocl, |
| 302 | + shape_types::dynamic_shape, |
| 303 | + paged_attention_impl::create, |
| 304 | + types, |
| 305 | + formats); |
| 306 | + |
| 307 | + implementation_map<paged_attention>::add(impl_types::ocl, |
| 308 | + shape_types::static_shape, |
| 309 | + paged_attention_impl::create, |
| 310 | + types, |
| 311 | + formats); |
| 312 | +} |
| 313 | + |
| 314 | +} // namespace detail |
| 315 | +} // namespace ocl |
| 316 | +} // namespace cldnn |
| 317 | + |
| 318 | +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::paged_attention_impl) |
| 319 | +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::paged_attention) |
0 commit comments