
Commit 15fe46e

[Speculative decoding] Fix draft_model run in case of long prompt (openvinotoolkit#1114)
Ticket: CVS-156390
Details: in case of a long prompt, the draft_model request was marked as paused and therefore was not scheduled.
1 parent e947742 commit 15fe46e

File tree

3 files changed · +14 −2 lines changed

src/cpp/src/sampler.cpp

+1 −1
@@ -586,7 +586,7 @@ void register_new_token(const Token& sampled_token_id,
         running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
     }
     if (!is_validation_mode_enabled &&
-        std::fabs(sampled_token_id.m_log_prob) < logit_processor.get_assistant_confidence_threshold()) {
+        std::fabs(std::exp(sampled_token_id.m_log_prob)) < logit_processor.get_assistant_confidence_threshold()) {
         auto sequence_group = running_sequence->get_sequence_group_ptr();
         sequence_group->pause_generation(true);
     }
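
The fix compares the token's probability, not the magnitude of its log-probability, against the assistant confidence threshold, which is defined on the probability scale. A minimal standalone sketch of the corrected comparison, with hypothetical values (not OpenVINO GenAI code):

    // Minimal sketch, not OpenVINO GenAI code: a log-probability lies in (-inf, 0],
    // and std::exp() maps it back to a probability in (0, 1], the scale on which
    // the assistant confidence threshold is defined. Values below are hypothetical.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float assistant_confidence_threshold = 0.4f;  // hypothetical threshold
        const float log_prob = -1.2f;                        // hypothetical sampled-token log-prob

        const float prob = std::exp(log_prob);               // ~0.30
        // The old check compared |log_prob| (1.2 here) against the threshold, mixing scales.
        // The new check compares probability against probability.
        if (prob < assistant_confidence_threshold) {
            std::printf("draft generation would be paused (%.2f < %.2f)\n",
                        prob, assistant_confidence_threshold);
        }
        return 0;
    }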

src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp

+4
@@ -254,6 +254,10 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
            prompt_len = request->get_prompt_len(),
            updated_context_len = min_candidate_len + prompt_len,
            max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+    // prompt phase
+    if (request->get_context_len() < request->get_prompt_len() && result.inserted_tokens_cnt == 0) {
+        return result;
+    }
     size_t generated_len = request->get_context_len() - request->get_prompt_len();
     if (num_processed_tokens > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt);
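
The new guard returns early while the request is still in the prompt (prefill) phase, i.e. the context length has not yet caught up with the prompt length and no draft tokens were inserted. A short sketch of why this matters, assuming the lengths are unsigned as generated_len is in the diff: without the guard, the later subtraction would wrap around.

    // Minimal sketch with hypothetical lengths: generated_len is an unsigned
    // difference, so it must only be computed once the prompt has been fully
    // consumed; otherwise the subtraction underflows to a huge value.
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t prompt_len = 10000;  // long prompt, processed over several steps
        const size_t context_len = 4096;  // only part of the prompt is in context so far

        if (context_len < prompt_len) {
            std::printf("still in prompt phase, skip the update\n");
            return 0;
        }
        const size_t generated_len = context_len - prompt_len;  // safe only after the guard
        std::printf("generated_len = %zu\n", generated_len);
        return 0;
    }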

src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp

+9 −1
@@ -48,7 +48,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(
     auto k = static_cast<float>(draft_model_cache_size) / (main_model_cache_size + draft_model_cache_size);
 
     size_t main_cache_size = main_scheduler_config.cache_size * (1 - k),
-           draft_cache_size = main_scheduler_config.cache_size * k;
+           draft_cache_size = main_scheduler_config.cache_size - main_cache_size;
     if (draft_cache_size == 0) {
         main_cache_size -= main_cache_size > 1 ? 1 : 0;
         draft_cache_size = 1;
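
Deriving draft_cache_size as the remainder of the total keeps the two shares summing exactly to main_scheduler_config.cache_size; computing both shares by multiplication can drop a unit to float-to-integer truncation. A small arithmetic sketch with made-up numbers:

    // Minimal sketch, hypothetical numbers: splitting one KV-cache budget between
    // the main and the draft model. Truncating both products can lose a unit;
    // taking the remainder keeps the split exact.
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t cache_size = 9;  // total budget (e.g. in GB)
        const float k = 0.39f;        // draft model's share of the budget

        const size_t main_cache = static_cast<size_t>(cache_size * (1 - k));  // 5
        const size_t draft_mul  = static_cast<size_t>(cache_size * k);        // 3 -> 5 + 3 = 8, one unit lost
        const size_t draft_rem  = cache_size - main_cache;                    // 4 -> 5 + 4 = 9, exact

        std::printf("multiplied split: %zu + %zu = %zu\n", main_cache, draft_mul, main_cache + draft_mul);
        std::printf("remainder split:  %zu + %zu = %zu\n", main_cache, draft_rem, main_cache + draft_rem);
        return 0;
    }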
@@ -158,6 +158,10 @@ void ContinuousBatchingPipeline::SpeculativeDecodingImpl::step() {
             m_draft_generations.erase(request_id);
         }
         auto updated_seq_info = update_sequence_info[request_id];
+        // several prompt phase
+        if (updated_seq_info.inserted_tokens_cnt == 0) {
+            continue;
+        }
         float acceptance_rate = 1 - static_cast<float>(updated_seq_info.removed_tokens_cnt) / updated_seq_info.inserted_tokens_cnt;
         m_sd_metrics.update_acceptance_rate(request_id, acceptance_rate * 100);
         m_sd_metrics.update_draft_accepted_tokens(request_id, (updated_seq_info.inserted_tokens_cnt - updated_seq_info.removed_tokens_cnt));
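
Skipping the metrics update when inserted_tokens_cnt is zero avoids a division by zero in the acceptance-rate formula on steps that only advanced the prompt. A standalone sketch of the same formula with that guard (hypothetical helper, not the pipeline's API):

    // Acceptance rate = 1 - removed / inserted, reported in percent. Steps where
    // the draft model inserted nothing must be skipped rather than divided by zero.
    #include <cstdio>

    float acceptance_rate_percent(int removed_tokens_cnt, int inserted_tokens_cnt) {
        if (inserted_tokens_cnt == 0) {
            return 0.0f;  // nothing was drafted this step; caller should skip the metric
        }
        return (1.0f - static_cast<float>(removed_tokens_cnt) / inserted_tokens_cnt) * 100.0f;
    }

    int main() {
        std::printf("%.1f%%\n", acceptance_rate_percent(1, 4));  // 75.0%
        std::printf("%.1f%%\n", acceptance_rate_percent(0, 0));  // guarded case
        return 0;
    }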
@@ -203,6 +207,10 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     while (has_non_finished_requests() && continue_generation) {
         step();
         if (streamer_ptr) {
+            // not generated tokens like several prompt phase
+            if (!main_generations.at(0).get()->can_read()) {
+                continue;
+            }
             std::unordered_map<uint64_t, GenerationOutput> token = main_generations.at(0).get()->back();
             OPENVINO_ASSERT(1 <= token.size());
             OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
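
With a long prompt, several scheduler steps can pass before the main model produces its first token, so the streaming loop now polls the generation handle before reading from it. A sketch of that poll-before-read pattern with a stand-in handle (the real GenerationHandle interface is not reproduced here):

    // Minimal sketch: skip streaming on steps that produced no output yet, as
    // happens while a long prompt is still being prefilled. The handle type
    // below is a stand-in, not the pipeline's real class.
    #include <cstdio>
    #include <queue>

    struct FakeGenerationHandle {
        std::queue<int> pending;  // token ids produced so far
        bool can_read() const { return !pending.empty(); }
        int read_token() { int t = pending.front(); pending.pop(); return t; }
    };

    int main() {
        FakeGenerationHandle gen;
        for (int step = 0; step < 5; ++step) {
            if (step >= 3) gen.pending.push(100 + step);  // tokens appear only after prefill
            if (!gen.can_read()) {
                continue;  // prompt-phase step: nothing to stream yet
            }
            std::printf("step %d: stream token %d\n", step, gen.read_token());
        }
        return 0;
    }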
