
Commit 0214ba8

Use get_max_new_tokens() instead of max_new_tokens field when stopping generation (#1417)
1 parent a31c003 commit 0214ba8

5 files changed: +10 -10 lines changed
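
Each file below makes the same substitution: instead of reading the raw max_new_tokens field, callers go through get_max_new_tokens(), which (as the sequence_group.hpp hunk shows) delegates to m_sampling_params.get_max_new_tokens(get_prompt_len()) and therefore folds the prompt length into the budget. As a rough illustration of why the two can differ, here is a minimal sketch of such a resolution, assuming an unset max_new_tokens falls back to max_length minus the prompt length; the struct is a hypothetical stand-in, not the actual ov::genai::GenerationConfig:

// Sketch only: an assumed resolution of the effective generation budget.
// GenerationConfigSketch is illustrative and not the real ov::genai::GenerationConfig.
#include <cstddef>
#include <limits>

struct GenerationConfigSketch {
    // "Unset" is modelled here as SIZE_MAX; the real config may use a different sentinel.
    size_t max_new_tokens = std::numeric_limits<size_t>::max();
    size_t max_length = std::numeric_limits<size_t>::max();

    // An explicit max_new_tokens wins; otherwise the budget is whatever room
    // max_length leaves after the prompt, clamped at zero.
    size_t get_max_new_tokens(size_t prompt_len) const {
        if (max_new_tokens != std::numeric_limits<size_t>::max()) {
            return max_new_tokens;
        }
        return max_length > prompt_len ? max_length - prompt_len : 0;
    }
};

Under that assumption, a request whose prompt already consumes max_length resolves to a budget of 0, which is the echo-only case the continuous_batching_impl.cpp change below can detect even when the raw field is non-zero.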

src/cpp/src/continuous_batching_impl.cpp (+1 -1)

@@ -771,7 +771,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
             }
             currently_processed_tokens += output_seq_len * num_running_sequences;
             // For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
-            if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
+            if(sequence_group->get_max_new_tokens() == 0) {
                 sequence_group->notify_handle_echo_only();
             }
         }

src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp (+1 -1)

@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(sampling_params.max_new_tokens, sampling_params.max_length) - generated_len - 1;
+            const auto left_generated_len = request->get_max_new_tokens() - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);

src/cpp/src/sequence_group.hpp (+2 -2)

@@ -492,7 +492,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
     }

     bool requires_sampling() const {
-        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.max_new_tokens > 0;
+        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && get_max_new_tokens() > 0;
     }

     void schedule_tokens(size_t num_tokens) {
@@ -749,7 +749,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
         m_generation_stream->push(std::move(outputs));
     }

-    size_t get_max_new_tokens() {
+    size_t get_max_new_tokens() const {
         return m_sampling_params.get_max_new_tokens(get_prompt_len());
     }
};
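
Since get_max_new_tokens() only reads the sampling parameters and the prompt length, it can be marked const, which lets the call sites above and below use it on const-qualified groups. A hypothetical helper, not part of the commit, showing the call pattern the diffs converge on (the old form read request.get_sampling_parameters().max_new_tokens directly):

// Illustrative sketch only; SequenceGroup is the class from sequence_group.hpp above.
bool reached_generation_budget(const SequenceGroup& request, size_t generated_len) {
    // get_max_new_tokens() already folds max_length and the prompt length into one budget.
    return generated_len >= request.get_max_new_tokens();
}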

src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp (+3 -3)

@@ -260,7 +260,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+                 max_new_tokens = request->get_max_new_tokens();
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);
@@ -323,13 +323,13 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
             // generate only one token in case of non speculative decoding
             request->pause_generation(true);
         } else if (request->get_num_processed_tokens() >= request->get_prompt_len() &&
-                   (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.max_new_tokens - 1) {
+                   (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= request->get_max_new_tokens() - 1) {
             request->pause_generation(true);
         } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
             request->pause_generation(true);
         } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
             request->pause_generation(true);
-        } else if (sampling_params.max_new_tokens == 0) {
+        } else if (request->get_max_new_tokens() == 0) {
            request->pause_generation(true);
        } else if (request->get_num_processed_tokens() == request->get_prompt_len()) {
            request->pause_generation(true);

src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp (+3 -3)

@@ -99,7 +99,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const ov::Tensor& input_ids,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.max_new_tokens);
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(input_ids.get_size()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;
@@ -112,7 +112,7 @@ GenerationHandle
 ContinuousBatchingPipeline::SpeculativeDecodingImpl::add_request(uint64_t request_id,
                                                                  const std::string& prompt,
                                                                  ov::genai::GenerationConfig sampling_params) {
-    m_sd_metrics.set_generated_len(request_id, sampling_params.max_new_tokens);
+    m_sd_metrics.set_generated_len(request_id, sampling_params.get_max_new_tokens(prompt.length()));
     std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
     auto draft_sampling_params = sampling_params;
     draft_sampling_params.ignore_eos = true;
@@ -245,7 +245,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<

     std::vector<GenerationHandle> main_generations;
     for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) {
-        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].max_new_tokens);
+        m_sd_metrics.set_generated_len(request_id, sampling_params[request_id].get_max_new_tokens(input_ids[request_id].get_size()));
         OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch.");
         main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], sampling_params[request_id]));