From 4a9ca61e0e5afd7d96c5236fb9106778d1797237 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Fri, 14 Feb 2025 17:31:01 +0100
Subject: [PATCH 1/9] Use get_max_new_tokens() instead of max_new_tokens field
 when stopping generation

---
 src/cpp/src/continuous_batching_impl.cpp                    | 2 +-
 .../prompt_lookup/continuous_batching_for_prompt_lookup.cpp | 2 +-
 src/cpp/src/sampler.cpp                                     | 1 +
 src/cpp/src/sampler.hpp                                     | 1 +
 src/cpp/src/sequence_group.hpp                              | 2 +-
 .../continuous_batching_for_speculative_decoding_impl.cpp   | 6 +++---
 .../src/speculative_decoding/speculative_decoding_impl.cpp  | 6 +++---
 7 files changed, 11 insertions(+), 9 deletions(-)
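Note (below the cut, ignored by git am): GenerationConfig::get_max_new_tokens(prompt_length)
resolves the effective token budget when max_new_tokens is left unset and only max_length is
configured. A minimal sketch of the assumed semantics, not the verbatim ov::genai code:

    // Assumed behaviour: an explicit max_new_tokens takes priority; otherwise the
    // budget is derived from max_length minus the prompt length.
    size_t GenerationConfig::get_max_new_tokens(size_t prompt_length) const {
        return max_new_tokens != SIZE_MAX ? max_new_tokens
                                          : max_length - prompt_length;
    }

Reading the max_new_tokens field directly misses the max_length-only case, which is why the
stopping checks below switch to the accessor.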
diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 81b0b9ca8c..651a3e8376 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -752,7 +752,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
         }
         currently_processed_tokens += output_seq_len * num_running_sequences;
         // For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
-        if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
+        if(sequence_group->get_sampling_parameters().get_max_new_tokens(sequence_group->get_prompt_len()) == 0) {
             sequence_group->notify_handle_echo_only();
         }
     }
diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
index aa4ea8a53a..9cafbac171 100644
--- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
+++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(sampling_params.max_new_tokens, sampling_params.max_length) - generated_len - 1;
+            const auto left_generated_len = std::min(sampling_params.get_max_new_tokens(request->get_prompt_len()), sampling_params.max_length) - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 3a7fd70ea5..7bc075ff4c 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -222,6 +222,7 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group,
         // to avoid selecting the same tokens for beams within group, let's just initialize score
         // for the front one
         group.ongoing.front().m_score = 0.0f;
+        group.prompt_len = this->m_sequence_group->get_prompt_len();
     }
 }
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index c53676d23c..8843fe0a2c 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -140,6 +140,7 @@ class Sampler::GroupBeamSearcher {
     struct Group {
         std::vector<Beam> ongoing;  // Best beams in front
         std::vector<Beam> min_heap; // The worst of the best completed beams is the first
+        size_t prompt_len;
         bool done = false;

         int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index c2861cdf18..7cd35e217b 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -457,7 +457,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
     }

     bool requires_sampling() const {
-        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.max_new_tokens > 0;
+        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.get_max_new_tokens(get_prompt_len()) > 0;
     }

     void schedule_tokens(size_t num_tokens) {
diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
index aabde5c0df..d117e55882 100644
--- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
@@ -260,7 +260,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+                 max_new_tokens = request->get_sampling_parameters().get_max_new_tokens(request->get_prompt_len());
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);
@@ -323,13 +323,13 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
         // generate only one token in case of non speculative decoding
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() >= request->get_prompt_len() &&
-               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.max_new_tokens - 1) {
+               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.get_max_new_tokens(request->get_prompt_len()) - 1) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
         request->pause_generation(true);
     } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
         request->pause_generation(true);
-    } else if (sampling_params.max_new_tokens == 0) {
+    } else if (sampling_params.get_max_new_tokens(request->get_prompt_len()) == 0) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == request->get_prompt_len()) {
         request->pause_generation(true);
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 2d0461a5c1..50c8fdcf13 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -95,7 +95,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const ov::Tensor& input_ids,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.max_new_tokens);
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(input_ids.get_size()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;
@@ -108,7 +108,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const std::string& prompt,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.max_new_tokens);
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(input_ids.get_size()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;
@@ -240,7 +240,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     std::vector<GenerationHandle> main_generations;
     for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
-        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].max_new_tokens);
+        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].get_max_new_tokens(input_ids.get_size()));
         OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
         main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id]));

From 02c86bc32e819d1a7b03ad84cf0cbff53f5ca097 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Fri, 14 Feb 2025 17:37:28 +0100
Subject: [PATCH 2/9] fix

---
 src/cpp/src/sampler.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 7bc075ff4c..191850c17c 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -1037,7 +1037,7 @@ void Sampler::GroupBeamSearcher::Group::is_done() {
             return;
         }
         case ov::genai::StopCriteria::NEVER: {
-            size_t length = sampling_params.length_penalty > 0.0 ? sequence_group->get_max_new_tokens() : cur_len;
+            size_t length = sampling_params.length_penalty > 0.0 ? sequence_group->get_max_new_tokens(this->prompt_len) : cur_len;
             float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
             done = worst_score >= highest_attainable_score;
             return;

From 4862030ff34be8f8a6e4254a40fb2cc0bf936446 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Mon, 24 Feb 2025 20:23:48 +0100
Subject: [PATCH 3/9] Addressing review comment

---
 src/cpp/src/continuous_batching_impl.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 651a3e8376..dd93677847 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -752,7 +752,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
         }
         currently_processed_tokens += output_seq_len * num_running_sequences;
         // For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
-        if(sequence_group->get_sampling_parameters().get_max_new_tokens(sequence_group->get_prompt_len()) == 0) {
+        if(sequence_group->get_max_new_tokens() == 0) {
             sequence_group->notify_handle_echo_only();
         }
     }

From fa350ad681bb1180bc196fcbffd5567398292677 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Mon, 3 Mar 2025 20:24:32 +0100
Subject: [PATCH 4/9] fix

---
 .../prompt_lookup/continuous_batching_for_prompt_lookup.cpp | 2 +-
 src/cpp/src/sampler.cpp                                     | 2 +-
 src/cpp/src/sampler.hpp                                     | 1 -
 src/cpp/src/sequence_group.hpp                              | 2 +-
 .../continuous_batching_for_speculative_decoding_impl.cpp   | 6 +++---
 5 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
index 9cafbac171..6e1b31903b 100644
--- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
+++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(sampling_params.get_max_new_tokens(request->get_prompt_len()), sampling_params.max_length) - generated_len - 1;
+            const auto left_generated_len = std::min(request->get_max_new_tokens(), sampling_params.max_length) - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 191850c17c..7bc075ff4c 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -1037,7 +1037,7 @@ void Sampler::GroupBeamSearcher::Group::is_done() {
             return;
         }
         case ov::genai::StopCriteria::NEVER: {
-            size_t length = sampling_params.length_penalty > 0.0 ? sequence_group->get_max_new_tokens(this->prompt_len) : cur_len;
+            size_t length = sampling_params.length_penalty > 0.0 ? sequence_group->get_max_new_tokens() : cur_len;
             float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
             done = worst_score >= highest_attainable_score;
             return;
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 8843fe0a2c..c53676d23c 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -140,7 +140,6 @@ class Sampler::GroupBeamSearcher {
     struct Group {
         std::vector<Beam> ongoing;  // Best beams in front
         std::vector<Beam> min_heap; // The worst of the best completed beams is the first
-        size_t prompt_len;
         bool done = false;

         int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 7cd35e217b..d4a48126aa 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -457,7 +457,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
     }

     bool requires_sampling() const {
-        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.get_max_new_tokens(get_prompt_len()) > 0;
+        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && get_max_new_tokens() > 0;
     }

     void schedule_tokens(size_t num_tokens) {
diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
index d117e55882..6e8ca340ae 100644
--- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
@@ -260,7 +260,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_sampling_parameters().get_max_new_tokens(request->get_prompt_len());
+                 max_new_tokens = request->get_max_new_tokens();;
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);
@@ -323,13 +323,13 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
         // generate only one token in case of non speculative decoding
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() >= request->get_prompt_len() &&
-               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.get_max_new_tokens(request->get_prompt_len()) - 1) {
+               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= request->get_max_new_tokens() - 1) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
         request->pause_generation(true);
     } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
         request->pause_generation(true);
-    } else if (sampling_params.get_max_new_tokens(request->get_prompt_len()) == 0) {
+    } else if (request->get_max_new_tokens() == 0) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == request->get_prompt_len()) {
         request->pause_generation(true);

From 4b153d91e86dd3a2efb21fda317fceb738aa2c9b Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Tue, 4 Mar 2025 09:43:34 +0100
Subject: [PATCH 5/9] fix

---
 src/cpp/src/sampler.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index 7bc075ff4c..3a7fd70ea5 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -222,7 +222,6 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group,
         // to avoid selecting the same tokens for beams within group, let's just initialize score
         // for the front one
         group.ongoing.front().m_score = 0.0f;
-        group.prompt_len = this->m_sequence_group->get_prompt_len();
     }
 }

From 060783eafb7eab3d9c462c74fb727fde4e1f7de1 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Tue, 4 Mar 2025 10:22:06 +0100
Subject: [PATCH 6/9] address review comments

---
 .../src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp | 2 +-
 src/cpp/src/sequence_group.hpp                                  | 2 +-
 .../continuous_batching_for_speculative_decoding_impl.cpp       | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
index 6e1b31903b..0ee24a92d8 100644
--- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
+++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(request->get_max_new_tokens(), sampling_params.max_length) - generated_len - 1;
+            const auto left_generated_len = std::min(request->get_max_new_tokens(), sampling_params.max_length - request->get_prompt_len()) - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index d4a48126aa..4dc6e7e06d 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -699,7 +699,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
         m_generation_stream->push(std::move(outputs));
     }

-    size_t get_max_new_tokens() {
+    size_t get_max_new_tokens() const {
         return m_sampling_params.get_max_new_tokens(get_prompt_len());
     }
 };
diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
index 6e8ca340ae..5909a41abe 100644
--- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
@@ -260,7 +260,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_max_new_tokens();;
+                 max_new_tokens = request->get_max_new_tokens();
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);

From b72b59bb5b9737bb9bbcf25be34aea0346ef1d93 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Tue, 4 Mar 2025 11:33:33 +0100
Subject: [PATCH 7/9] fix

---
 .../src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
index 0ee24a92d8..45401a1825 100644
--- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
+++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(request->get_max_new_tokens(), sampling_params.max_length - request->get_prompt_len()) - generated_len - 1;
+            const auto left_generated_len = request->get_max_new_tokens() - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);

From 4cd4fe625520f659af57e4991ed8011e22d77610 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Tue, 4 Mar 2025 16:03:05 +0100
Subject: [PATCH 8/9] fix

---
 src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 50c8fdcf13..0db71b7917 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -240,7 +240,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     std::vector<GenerationHandle> main_generations;
     for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
-        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].get_max_new_tokens(input_ids.get_size()));
+        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].get_max_new_tokens(input_ids[request_id].get_size()));
         OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
         main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id]));

From 5075eb744231a8665c58587324a374c25d99cc38 Mon Sep 17 00:00:00 2001
From: michalkulakowski
Date: Wed, 5 Mar 2025 09:22:11 +0100
Subject: [PATCH 9/9] fix

---
 src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 0db71b7917..586b9f34e3 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -108,7 +108,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const std::string& prompt,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(input_ids.get_size()));
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(prompt.length()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;