diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 81b0b9ca8c..dd93677847 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -752,7 +752,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
         }
         currently_processed_tokens += output_seq_len * num_running_sequences;
         // For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
-        if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
+        if(sequence_group->get_max_new_tokens() == 0) {
             sequence_group->notify_handle_echo_only();
         }
     }
diff --git a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
index aa4ea8a53a..45401a1825 100644
--- a/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
+++ b/src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp
@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(sampling_params.max_new_tokens, sampling_params.max_length) - generated_len - 1;
+            const auto left_generated_len = request->get_max_new_tokens() - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index c2861cdf18..4dc6e7e06d 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -457,7 +457,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
     }
 
     bool requires_sampling() const {
-        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.max_new_tokens > 0;
+        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && get_max_new_tokens() > 0;
     }
 
     void schedule_tokens(size_t num_tokens) {
@@ -699,7 +699,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
         m_generation_stream->push(std::move(outputs));
     }
 
-    size_t get_max_new_tokens() {
+    size_t get_max_new_tokens() const {
         return m_sampling_params.get_max_new_tokens(get_prompt_len());
     }
 };
diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
index aabde5c0df..5909a41abe 100644
--- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp
@@ -260,7 +260,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+                 max_new_tokens = request->get_max_new_tokens();
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);
@@ -323,13 +323,13 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
            // generate only one token in case of non speculative decoding
            request->pause_generation(true);
        } else if (request->get_num_processed_tokens() >= request->get_prompt_len() &&
-                  (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.max_new_tokens - 1) {
+                  (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= request->get_max_new_tokens() - 1) {
            request->pause_generation(true);
        } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
            request->pause_generation(true);
        } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
            request->pause_generation(true);
-       } else if (sampling_params.max_new_tokens == 0) {
+       } else if (request->get_max_new_tokens() == 0) {
            request->pause_generation(true);
        } else if (request->get_num_processed_tokens() == request->get_prompt_len()) {
            request->pause_generation(true);
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 2d0461a5c1..586b9f34e3 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -95,7 +95,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const ov::Tensor& input_ids,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.max_new_tokens);
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(input_ids.get_size()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;
@@ -108,7 +108,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const std::string& prompt,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.max_new_tokens);
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(prompt.length()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;
@@ -240,7 +240,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     std::vector<GenerationHandle> main_generations;
     for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
-        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].max_new_tokens);
+        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].get_max_new_tokens(input_ids[request_id].get_size()));
         OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
         main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id]));
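
Context for reviewers: the sequence_group.hpp hunk shows SequenceGroup::get_max_new_tokens() delegating to m_sampling_params.get_max_new_tokens(get_prompt_len()), which is why the call sites above now pass a prompt/input size instead of reading sampling_params.max_new_tokens directly. The sketch below is only an assumption about how such a GenerationConfig-level helper could resolve the effective budget from max_new_tokens, max_length and the prompt length (the struct name, field defaults and sentinel are illustrative, not the library source); it is included to clarify the intent of the change, not to quote the implementation.

    // Hypothetical sketch, not the OpenVINO GenAI source: derive an effective
    // "max new tokens" budget from max_new_tokens, max_length and the prompt length.
    #include <cstddef>
    #include <limits>

    struct GenerationConfigSketch {
        size_t max_new_tokens = std::numeric_limits<size_t>::max();  // "unset" sentinel (assumption)
        size_t max_length = std::numeric_limits<size_t>::max();      // total length cap, prompt included

        size_t get_max_new_tokens(size_t prompt_length) const {
            // If max_new_tokens is set explicitly it wins; otherwise the budget is
            // whatever max_length leaves after the prompt (never underflowing).
            if (max_new_tokens != std::numeric_limits<size_t>::max())
                return max_new_tokens;
            return max_length > prompt_length ? max_length - prompt_length : 0;
        }
    };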