
Commit c1d1672

Address review comments
1 parent 8454e24 commit c1d1672

6 files changed: +11 −9

src/cpp/src/continuous_batching_impl.cpp

+1 −1
@@ -522,7 +522,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
         }
         currently_processed_tokens += output_seq_len * num_running_sequences;
         // For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
-        if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
+        if(sequence_group->get_sampling_parameters().get_max_new_tokens(sequence_group->get_prompt_len()) == 0) {
             sequence_group->notify_handle_echo_only();
         }
     }
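
The replacement swaps a raw field read for a resolved value. A minimal sketch of the presumed getter semantics (the "unset" sentinel and the fallback to max_length are assumptions for illustration, not taken from this diff):

#include <cstddef>
#include <limits>

// Hypothetical stand-in for the generation config, illustrating why the
// resolved getter can differ from the raw max_new_tokens field.
struct GenerationConfigSketch {
    size_t max_new_tokens = std::numeric_limits<size_t>::max(); // assumed "unset" sentinel
    size_t max_length = 20;

    size_t get_max_new_tokens(size_t prompt_length) const {
        if (max_new_tokens != std::numeric_limits<size_t>::max())
            return max_new_tokens; // an explicit cap wins
        // otherwise the budget is whatever max_length leaves after the prompt
        return max_length > prompt_length ? max_length - prompt_length : 0;
    }
};

Under these assumed semantics, the raw field only equals 0 when a caller sets it explicitly, while the resolved value is also 0 whenever the prompt consumes all of max_length, which is exactly the echo-only case this hunk notifies.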

src/cpp/src/prompt_lookup/continuous_batching_for_prompt_lookup.cpp

+1 −1
@@ -67,7 +67,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForPromptLookupImpl::generate
         const auto sampling_params = request->get_sampling_parameters();
         {
             const auto generated_len = running_sequence->get_generated_len();
-            const auto left_generated_len = std::min(sampling_params.max_new_tokens, sampling_params.max_length) - generated_len - 1;
+            const auto left_generated_len = std::min(sampling_params.get_max_new_tokens(request->get_prompt_len()), sampling_params.max_length) - generated_len - 1;
             min_num_assistant_tokens = std::min(sampling_params.num_assistant_tokens, left_generated_len);
         }
         TokenIds candidates = generate_candidates(full_input_ids, min_num_assistant_tokens, sampling_params.max_ngram_size);
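
The updated expression still subtracts from an unsigned minimum. A stand-alone restatement of that arithmetic, with all names as illustrative stand-ins for the accessors used above:

#include <algorithm>
#include <cstddef>

// Illustrative restatement of the clamp above; everything is unsigned, so the
// caller is presumed to guarantee that generated_len + 1 never exceeds the
// effective budget, otherwise the subtraction wraps around.
size_t clamp_assistant_tokens(size_t effective_max_new_tokens, size_t max_length,
                              size_t generated_len, size_t num_assistant_tokens) {
    const size_t left_generated_len =
        std::min(effective_max_new_tokens, max_length) - generated_len - 1;
    return std::min(num_assistant_tokens, left_generated_len);
}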

src/cpp/src/sampler.cpp

+3 −2
@@ -215,6 +215,7 @@ Sampler::GroupBeamSearcher::GroupBeamSearcher(SequenceGroup::Ptr sequence_group,
         // to avoid selecting the same tokens for beams within group, let's just initialize score
         // for the front one
         group.ongoing.front().m_score = 0.0f;
+        group.prompt_len = this->m_sequence_group->get_prompt_len();
     }
 }

@@ -408,7 +409,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits,
             }

             // check whether group has finished
-            group.is_done(m_parameters, this->m_sequence_group->get_prompt_len());
+            group.is_done(m_parameters);

             // group cannot continue if there are no valid child beams
             if (child_beams_per_group[group_id].size() == 0) {
@@ -956,7 +957,7 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
     return preeempted_sequence_id;
 }

-void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len) {
+void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
     assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
         "number of beams should be divisible by number of groups");
     size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;

src/cpp/src/sampler.hpp

+2 −1
@@ -111,10 +111,11 @@ class Sampler::GroupBeamSearcher {
     struct Group {
         std::vector<Beam> ongoing;  // Best beams in front
         std::vector<Beam> min_heap; // The worst of the best completed beams is the first
+        size_t prompt_len;
         bool done = false;

         int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
-        void is_done(const ov::genai::GenerationConfig& sampling_params, size_t prompt_len);
+        void is_done(const ov::genai::GenerationConfig& sampling_params);
     };

     SequenceGroup::Ptr m_sequence_group;
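
Taken together with the sampler.cpp hunks above, this is one refactor: prompt_len moves from an is_done() parameter to a field that GroupBeamSearcher fills in once at construction. A reduced sketch of the shape, with unrelated members and bodies elided:

#include <cstddef>

// Reduced sketch of the refactor; the real types carry many more members.
struct GroupSketch {
    size_t prompt_len = 0; // cached here instead of threaded through every call
    bool done = false;

    // Before: is_done(sampling_params, prompt_len). After: the group reads its own copy.
    void is_done(/* const GenerationConfig& sampling_params */) {
        // ... length-based termination checks can use this->prompt_len directly ...
    }
};

struct SearcherSketch {
    GroupSketch group;
    explicit SearcherSketch(size_t sequence_group_prompt_len) {
        // mirrors: group.prompt_len = this->m_sequence_group->get_prompt_len();
        group.prompt_len = sequence_group_prompt_len;
    }
};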

src/cpp/src/sequence_group.hpp

+1 −1
@@ -457,7 +457,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
     }

     bool requires_sampling() const {
-        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.max_new_tokens > 0;
+        return get_context_len() >= get_prompt_len() && get_context_len() > m_max_content_len && m_sampling_params.get_max_new_tokens(this->get_prompt_len()) > 0;
     }

     void schedule_tokens(size_t num_tokens) {
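
Continuing the GenerationConfigSketch assumption from the first sketch, a short usage snippet for the case this guard now handles:

// A default config whose max_length is fully consumed by the prompt yields an
// effective budget of 0, so requires_sampling() correctly reports false and the
// request is treated as echo-only.
GenerationConfigSketch cfg;                // max_new_tokens unset, max_length = 20
const size_t prompt_len = 20;
const bool samples = cfg.get_max_new_tokens(prompt_len) > 0; // false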

src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp

+3 −3
@@ -261,7 +261,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+                 max_new_tokens = request->get_sampling_parameters().get_max_new_tokens(request->get_prompt_len());
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);
@@ -324,13 +324,13 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
         // generate only one token in case of non speculative decoding
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() >= request->get_prompt_len() &&
-               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.max_new_tokens - 1) {
+               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.get_max_new_tokens(request->get_prompt_len()) - 1) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
         request->pause_generation(true);
     } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
         request->pause_generation(true);
-    } else if (sampling_params.max_new_tokens == 0) {
+    } else if (sampling_params.get_max_new_tokens(request->get_prompt_len()) == 0) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == request->get_prompt_len()) {
         request->pause_generation(true);
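
A hypothetical condensation of the branch chain above into a single predicate, with branch order preserved from the hunk; every parameter is a stand-in for the corresponding request or config accessor:

#include <cstddef>

// Hypothetical restatement of the pause conditions in the hunk above.
// Branch order matters: with an effective budget of 0, the first condition's
// unsigned "budget - 1" wraps to SIZE_MAX and the comparison stays false, so
// the dedicated "== 0" branch catches the echo-only case instead.
bool should_pause(size_t num_processed_tokens, size_t prompt_len,
                  size_t effective_max_new_tokens, // get_max_new_tokens(prompt_len)
                  size_t num_return_sequences, size_t num_assistant_tokens,
                  size_t generated_tokens_cnt, float assistant_confidence_threshold) {
    if (num_processed_tokens >= prompt_len &&
        (num_processed_tokens - prompt_len + 1) >= effective_max_new_tokens - 1)
        return true; // generation budget (nearly) spent
    if (num_processed_tokens == 0 && num_return_sequences > 1)
        return true; // multi-sequence requests start without speculation
    if (num_assistant_tokens <= generated_tokens_cnt &&
        assistant_confidence_threshold == 0.f)
        return true; // fixed-size candidate window exhausted
    if (effective_max_new_tokens == 0)
        return true; // echo-only request: nothing to draft
    if (num_processed_tokens == prompt_len)
        return true; // prompt just finished; take one normal step first
    return false;
}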
