// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "text_callback_streamer.hpp"
#include "continuous_batching_impl.hpp"
#include "paged_attention_transformations.hpp"

namespace ov::genai {
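// Helper that combines several lambdas into a single overload set for std::visit;
// used below to dispatch on the StreamerVariant alternatives in generate().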
template<class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

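// Builds the continuous batching pipeline: reads the model, applies the PagedAttention
// transformations, compiles it for the target device, allocates the KV-cache blocks, and
// creates the scheduler, model runner and sampler that step() orchestrates.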
ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
    const std::string& models_path,
    const Tokenizer& tokenizer,
    const SchedulerConfig& scheduler_config,
    const std::string& device,
    const ov::AnyMap& plugin_config) {
    m_tokenizer = tokenizer;
    ov::Core core;

    // The model can be compiled for GPU as well
    std::shared_ptr<ov::Model> model = core.read_model(models_path + "/openvino_model.xml");

    DeviceConfig device_config(core, scheduler_config, device, plugin_config);

    apply_paged_attention_transformations(model, device_config);

    ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), plugin_config).create_infer_request();

    // setup KV caches
    m_cache_manager = std::make_shared<CacheManager>(device_config, core);
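    // The PagedAttention-transformed model exposes one key-cache and one value-cache input
    // per decoder layer; bind each of them to the tensors pre-allocated by the CacheManager.
    // The cache inputs follow the regular model inputs, hence the index offset of 2 below.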
    for (size_t decoder_layer_id = 0; decoder_layer_id < device_config.get_num_layers(); ++decoder_layer_id) {
        infer_request.set_input_tensor(2 + decoder_layer_id * 2, m_cache_manager->get_key_cache(decoder_layer_id));
        infer_request.set_input_tensor(2 + decoder_layer_id * 2 + 1, m_cache_manager->get_value_cache(decoder_layer_id));
    }

    SchedulerConfig updated_config = scheduler_config;
    // update the number of KV blocks with the value computed by the device config
    if (scheduler_config.num_kv_blocks != device_config.get_num_kv_blocks()) {
        updated_config.num_kv_blocks = device_config.get_num_kv_blocks();
    }

    bool can_use_partial_preemption = true;
    if (device_config.get_device().find("GPU") != std::string::npos && !updated_config.dynamic_split_fuse) {
        // in case of executing a `vLLM-like` pipeline, it's better not to use partial eviction on the GPU,
        // as it may lead to a performance slowdown
        can_use_partial_preemption = false;
    }

    m_scheduler = std::make_shared<Scheduler>(updated_config, can_use_partial_preemption);
    // and finally create the model runner and the sampler
    m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config);
    m_sampler = std::make_shared<Sampler>(m_tokenizer);
    m_sampler->set_seed(m_generation_config.rng_seed);

    // read default generation config
}

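// Wraps an already tokenized prompt into a SequenceGroup, optionally restores prefix-cached
// KV blocks, and queues it for the next step(). Returns a handle the caller can use to
// stream or read back the generated tokens.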
GenerationHandle
ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id,
                                                                const ov::Tensor& input_ids,
                                                                ov::genai::GenerationConfig sampling_params) {
    sampling_params.set_eos_token_id(m_tokenizer.get_eos_token_id());
    sampling_params.validate();
    SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids,
                                                                        sampling_params,
                                                                        m_scheduler->get_config().block_size,
                                                                        m_scheduler->get_config().enable_prefix_caching);
    sequence_group->set_sequence_group_ptr(sequence_group);
    if (m_scheduler->get_config().enable_prefix_caching) {
        m_scheduler->restore_cached_blocks(sequence_group);
    }

    {
        std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
        m_awaiting_requests.push_back(sequence_group);
    }
    return std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
}

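// Convenience overload: tokenizes the prompt first (timed separately) and forwards to the
// tensor-based add_request() above.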
GenerationHandle
ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id,
                                                                const std::string& prompt,
                                                                ov::genai::GenerationConfig sampling_params) {
    static ManualTimer timer("tokenize");
    timer.start();
    ov::Tensor input_ids = m_tokenizer.encode(prompt).input_ids;
    timer.end();
    return add_request(request_id, input_ids, sampling_params);
}

bool ContinuousBatchingPipeline::ContinuousBatchingImpl::has_non_finished_requests() {
    std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
    return !m_awaiting_requests.empty() || !m_requests.empty();
}

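// Runs one iteration of the continuous batching loop: pull newly added requests, let the
// scheduler pick sequences and KV blocks, run a single forward pass, sample the next tokens,
// then apply fork/free decisions and release finished or dropped requests.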
void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
    static ManualTimer step_timer("step()");
    step_timer.start();

    // Pull awaiting requests
    {
        std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
        m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end());
        m_awaiting_requests.clear();
    }

    m_pipeline_metrics.requests = m_requests.size();
    Scheduler::Output scheduler_output;
    {
        static ManualTimer timer("scheduling");
        timer.start();
        scheduler_output = m_scheduler->schedule(m_requests);
        m_pipeline_metrics.scheduled_requests = scheduler_output.m_scheduled_sequence_groups_ids.size();
        m_pipeline_metrics.cache_usage = scheduler_output.m_cache_usage;
        m_cache_manager->copy_blocks(scheduler_output.m_block_copy_map);
        timer.end();
    }

    // if no tokens were scheduled, the KV cache is exhausted: mark every request as
    // out of memory and notify its handle so that waiting callers are unblocked
    if (scheduler_output.m_total_num_scheduled_tokens == 0) {
        for (size_t i = 0; i < m_requests.size(); ++i) {
            SequenceGroup::Ptr sequence_group = m_requests[i];
            sequence_group->set_out_of_memory();
            sequence_group->notify_handle();
        }
        _free_non_running_requests();
        return;
    }

    ov::Tensor logits;
    {
        static ManualTimer timer("forward");
        timer.start();
        logits = m_model_runner->forward(m_requests, scheduler_output);
        timer.end();

        ov::InferRequest infer_request = m_model_runner->get_infer_request();
        ov::CompiledModel compiled_model = infer_request.get_compiled_model();
        const bool is_profiling_enabled = compiled_model.get_property(ov::enable_profiling);

        // collect detailed per-node statistics when profiling is enabled
        if (is_profiling_enabled) {
            std::vector<ov::ProfilingInfo> profiling_info = infer_request.get_profiling_info();
            for (const ov::ProfilingInfo& info : profiling_info) {
                double current_time = info.real_time.count();
                if (info.node_type == "PagedAttentionExtension") {
                    m_perf.m_paged_attention_time_ms += current_time;
                } else if (info.node_type == "FullyConnected") {
                    m_perf.m_matmul_time_ms += current_time;
                }
                m_perf.m_infer_total_ms += current_time;
            }
        }
    }

    SamplerOutput sampler_output;
    {
        static ManualTimer timer("sample");
        timer.start();
        sampler_output = m_sampler->sample(m_requests, logits);
        timer.end();
    }

    // process sampler_output (e.g. fork or drop sequences from BlockScheduler)
    {
        static ManualTimer timer("fork / free sequence");
        timer.start();

        for (const auto& pair : sampler_output.m_forked_sequences) {
            uint64_t parent_id = pair.first;
            const std::list<uint64_t>& child_ids = pair.second;
            for (auto& child_id : child_ids)
                m_scheduler->fork_sequence(parent_id, child_id);
        }

        for (auto seq_id : sampler_output.m_dropped_sequences)
            m_scheduler->free_sequence(seq_id);

        timer.end();
    }

    // notify requests dropped by handle
    {
        static ManualTimer timer("notify requests dropped by handle");
        timer.start();
        _notify_requests_dropped_by_handle();
        timer.end();
    }

    // free non-running requests for the current step
    {
        static ManualTimer timer("free non running requests");
        timer.start();
        _free_non_running_requests();
        timer.end();
    }

    step_timer.end();
}

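// Synchronous batch generation over pre-tokenized inputs: every prompt becomes a request,
// step() is called until all requests finish (or the streamer cancels), and the per-request
// outputs are collected, sorted by score and truncated to num_return_sequences.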
std::vector<EncodedGenerationResult>
ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
                                                             const std::vector<GenerationConfig>& sampling_params,
                                                             const StreamerVariant& streamer) {
    OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
    OPENVINO_ASSERT(input_ids.size() == sampling_params.size());
    const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
        [](std::monostate) -> std::shared_ptr<StreamerBase> {
            return nullptr;
        },
        [](const std::shared_ptr<StreamerBase>& streamer) {
            return streamer;
        },
        [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
            return std::make_shared<TextCallbackStreamer>(m_tokenizer, streamer);
        }
    }, streamer);

    std::vector<GenerationHandle> generations;
    for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
        OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
        generations.push_back(add_request(request_id, input_ids[request_id], sampling_params[request_id]));
    }

    std::vector<EncodedGenerationResult> results;
    results.reserve(input_ids.size());

    bool continue_generation = true;
    while (has_non_finished_requests() && continue_generation) {
        step();
        if (streamer_ptr) {
            std::unordered_map<uint64_t, GenerationOutput> token = generations.at(0)->back();
            OPENVINO_ASSERT(1 == token.size());
            OPENVINO_ASSERT(1 == token.begin()->second.generated_ids.size());
            continue_generation = !streamer_ptr->put(token.begin()->second.generated_ids.at(0));
        }
    }
    if (streamer_ptr) {
        streamer_ptr->end();
    }

    for (size_t generation_idx = 0; generation_idx < generations.size(); ++generation_idx) {
        const auto& generation = generations[generation_idx];
        EncodedGenerationResult result;
        result.m_request_id = generation_idx;
        std::vector<GenerationOutput> generation_outputs = generation->read_all();
        std::sort(generation_outputs.begin(), generation_outputs.end(), [] (GenerationOutput& r1, GenerationOutput& r2) {
            return r1.score > r2.score;
        });

        auto num_outputs = std::min(sampling_params[generation_idx].num_return_sequences, generation_outputs.size());
        for (size_t generation_output_idx = 0; generation_output_idx < num_outputs; ++generation_output_idx) {
            auto& generation_output = generation_outputs[generation_output_idx];
            result.m_generation_ids.push_back(std::move(generation_output.generated_ids));
            result.m_scores.push_back(generation_output.score);
        }
        result.m_status = generation->get_status();
        results.push_back(std::move(result));
    }

    OPENVINO_ASSERT(results.size() == input_ids.size());
    return results;
}

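// String-based overload: tokenizes the prompts (applying the chat template and updating the
// chat history when a chat session is active), delegates to the tensor-based generate() and
// decodes the returned token ids back into text.
//
// Illustrative use through the public pipeline wrapper (a sketch only, assuming a model
// exported for continuous batching and default scheduler settings; variable names are
// hypothetical):
//
//   ov::genai::SchedulerConfig scheduler_config;  // default KV-cache settings
//   ov::genai::ContinuousBatchingPipeline pipeline(models_path, scheduler_config, "CPU");
//   auto results = pipeline.generate({"What is OpenVINO?"}, {ov::genai::greedy()});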
std::vector<GenerationResult>
ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<std::string>& prompts,
                                                             std::vector<ov::genai::GenerationConfig> sampling_params,
                                                             const StreamerVariant& streamer) {
    std::vector<ov::Tensor> input_ids;
    static ManualTimer timer("tokenize");
    if (m_is_chat_conversation) {
        OPENVINO_ASSERT(1 == prompts.size(), "Can't chat with multiple prompts");
        m_history.push_back({{"role", "user"}, {"content", prompts.at(0)}});
        constexpr bool add_generation_prompt = true;
        std::string history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
        timer.start();
        input_ids.push_back(m_tokenizer.encode(history).input_ids);
        timer.end();
    } else {
        input_ids.reserve(prompts.size());
        for (const std::string& prompt : prompts) {
            timer.start();
            input_ids.push_back(m_tokenizer.encode(prompt).input_ids);
            timer.end();
        }
    }
    std::vector<EncodedGenerationResult> encoded = generate(input_ids, sampling_params, streamer);
    std::vector<GenerationResult> decoded;
    decoded.reserve(encoded.size());
    for (EncodedGenerationResult& res : encoded) {
        std::vector<std::string> generated;
        generated.reserve(res.m_generation_ids.size());
        for (size_t idx = 0; idx < res.m_generation_ids.size(); ++idx) {
            generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
            if (m_is_chat_conversation && 0 == idx) {
                m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
            }
        }
        decoded.push_back(GenerationResult{
            res.m_request_id,
            std::move(generated),
            std::move(res.m_scores),
            res.m_status
        });
    }
    return decoded;
}

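// Erases requests that have finished, ran out of memory or whose handle was dropped,
// releasing their KV blocks in the scheduler and any beam-search state in the sampler.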
void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_requests() {
    std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
    while (requests_iterator != m_requests.end()) {
        const auto& request = *requests_iterator;
        if (request->has_finished() || request->out_of_memory() || request->handle_dropped()) {
            for (const auto& sequence : request->get_sequences()) {
                m_scheduler->free_sequence(sequence->get_id());
            }
            m_sampler->clear_beam_search_info(request->get_request_id());
            requests_iterator = m_requests.erase(requests_iterator);
        } else {
            ++requests_iterator;
        }
    }
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_dropped_by_handle() {
    // Notify the handle one last time by pushing an empty output;
    // adding anything to the queue unblocks a pending read()
    for (SequenceGroup::Ptr& request : m_requests) {
        if (request->handle_dropped())
            request->push_empty_outputs();
    }
}
}  // namespace ov::genai