Commit dbb8c07
Fix for CI
1 parent: b7cf039

4 files changed: +14 -23 lines changed

src/cpp/src/logit_processor.hpp (-6)

@@ -312,7 +312,6 @@ class LogitProcessor {
 
     // speculative decoding parameters
     float m_assistant_confidence_threshold = 0.f;
-    bool m_is_dynamic_speculative_decoding = false;
 
 
 public:
@@ -360,16 +359,11 @@ class LogitProcessor {
                 }
             }
             if (sampling_params.assistant_confidence_threshold > 0) {
-                m_is_dynamic_speculative_decoding = true;
                 m_assistant_confidence_threshold = sampling_params.assistant_confidence_threshold;
             }
         }
     }
 
-    bool is_dynamic_speculative_decoding() {
-        return m_is_dynamic_speculative_decoding;
-    }
-
     float get_assistant_confidence_threshold() {
         return m_assistant_confidence_threshold;
     }
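
With the m_is_dynamic_speculative_decoding flag removed, a strictly positive assistant_confidence_threshold is what now signals that the dynamic per-token check applies. A minimal standalone sketch of the resulting check (toy struct and example values, not the library class):

// Toy sketch, not the OpenVINO GenAI class: with the flag gone, "dynamic
// speculative decoding enabled" reduces to "threshold > 0".
#include <cmath>
#include <iostream>

struct LogitProcessorSketch {
    float m_assistant_confidence_threshold = 0.f;

    float get_assistant_confidence_threshold() const {
        return m_assistant_confidence_threshold;
    }
    // Hypothetical helper standing in for the removed is_dynamic_speculative_decoding():
    // a non-zero threshold is the only signal that remains.
    bool uses_dynamic_threshold() const {
        return m_assistant_confidence_threshold > 0.f;
    }
};

int main() {
    LogitProcessorSketch lp;
    lp.m_assistant_confidence_threshold = 0.4f;  // assumed example value
    float token_log_prob = -0.1f;                // assumed example value
    // Same shape of check as in register_new_token: pause the draft request when
    // |log_prob| falls below the configured threshold.
    bool pause = std::fabs(token_log_prob) < lp.get_assistant_confidence_threshold();
    std::cout << std::boolalpha << lp.uses_dynamic_threshold() << " " << pause << "\n";
    return 0;
}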

src/cpp/src/sampler.cpp (+10 -15)

@@ -565,23 +565,12 @@ void register_new_token(const Token& sampled_token_id,
                         Sequence::Ptr running_sequence,
                         LogitProcessor& logit_processor,
                         bool is_extend_sequence,
-                        bool is_update_len_logit_processor,
                         bool is_validation_mode_enabled) {
     logit_processor.register_new_generated_token(sampled_token_id.m_index);
-    size_t generated_len = logit_processor.get_generated_len();
     if (is_extend_sequence) {
         running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
-    } else {
-        // just update the token log prob in case of successfully validated token
-        OPENVINO_ASSERT(generated_len < running_sequence->get_generated_len());
-        running_sequence->update_generated_log_prob(generated_len, sampled_token_id.m_log_prob);
-    }
-    // increment seq len only for one sequence in sequence group to sync them
-    if (is_update_len_logit_processor) {
-        logit_processor.update_generated_len(++generated_len);
     }
     if (!is_validation_mode_enabled &&
-        logit_processor.is_dynamic_speculative_decoding() &&
         std::fabs(sampled_token_id.m_log_prob) < logit_processor.get_assistant_confidence_threshold()) {
         auto sequence_group = running_sequence->get_sequence_group_ptr();
         sequence_group->pause_generation(true);

@@ -604,7 +593,7 @@ create_n_forked_sequences(SequenceGroup::Ptr sequence_group,
         const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
         const auto forked_seq_id = forked_sequence->get_id();
         forked_seq_ids.push_back(forked_seq_id);
-        register_new_token(sampled_tokens[i], forked_sequence, logit_processor, true, false, false);
+        register_new_token(sampled_tokens[i], forked_sequence, logit_processor, true, false);
     }
     return forked_seq_ids;
 }

@@ -616,6 +605,8 @@ stop_sample_tokens(Sequence::Ptr running_sequence,
                    size_t& max_removed_tokens_per_request) {
     running_sequence->remove_last_tokens(token_idx);
     max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_idx);
+    running_sequence->set_status(SequenceStatus::FINISHED);
+    running_sequence->set_finish_reason(GenerationFinishReason::STOP);
 }
 
 void
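
Beyond trimming the tokens that follow a matched stop condition, stop_sample_tokens now also marks the sequence as finished with a STOP reason. A toy sketch of that end state (hand-rolled stand-in types, not the pipeline's SequenceStatus / GenerationFinishReason):

// Toy sketch with stand-in types; the real enums and Sequence class live in the library.
#include <algorithm>
#include <cstddef>
#include <vector>

enum class SequenceStatus { RUNNING, FINISHED };
enum class GenerationFinishReason { NONE, STOP, LENGTH };

struct SequenceSketch {
    std::vector<int> tokens;
    SequenceStatus status = SequenceStatus::RUNNING;
    GenerationFinishReason finish_reason = GenerationFinishReason::NONE;

    void remove_last_tokens(std::size_t n) { tokens.resize(tokens.size() - n); }
};

// Mirrors the updated stop_sample_tokens: drop the tail, record the maximum
// number of removed tokens, and finish the sequence with a STOP reason.
void stop_sample_tokens_sketch(SequenceSketch& seq, std::size_t token_idx,
                               std::size_t& max_removed_tokens_per_request) {
    seq.remove_last_tokens(token_idx);
    max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_idx);
    seq.status = SequenceStatus::FINISHED;
    seq.finish_reason = GenerationFinishReason::STOP;
}

int main() {
    SequenceSketch seq{{1, 2, 3, 4, 5}};             // assumed example tokens
    std::size_t max_removed = 0;
    stop_sample_tokens_sketch(seq, 2, max_removed);  // assume the stop match trims 2 tokens
    return seq.status == SequenceStatus::FINISHED ? 0 : 1;
}
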
@@ -742,13 +733,16 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
                 }
                 // flag to add sampled token to generated sequence or extend logit processors only
                 bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens,
-                // flag to update generated length of sequence group in logit processor
-                     is_update_len_logit_processor = running_sequence_id == num_running_sequences - 1,
                      is_validation_passed = true;
                 if (is_validation_mode_enabled && !is_generate_n_tokens) {
                     is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token_id, is_extend_sequence, max_removed_tokens_per_request);
+                    // update log prob just while validation process
+                    if (!is_extend_sequence) {
+                        OPENVINO_ASSERT(generated_and_verified_len < running_sequences[running_sequence_id]->get_generated_len());
+                        running_sequence->update_generated_log_prob(generated_and_verified_len, sampled_token_id.m_log_prob);
+                    }
                 }
-                register_new_token(sampled_token_id, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_update_len_logit_processor, is_validation_mode_enabled);
+                register_new_token(sampled_token_id, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_validation_mode_enabled);
                 // to exit from sampling in case of failed token validation
                 if (!is_validation_passed) {
                     break;

@@ -794,6 +788,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
             align_all_sequence_len(sequence_group, min_generated_len, logit_processor);
             auto min_processed_tokens = sequence_group->get_prompt_len() + min_generated_len - 1;
             sequence_group->update_processed_tokens_num(min_processed_tokens);
+            logit_processor.update_generated_len(min_processed_tokens);
         }
 
         // accumulate a number of processed tokens
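
Taken together with the register_new_token changes above, generated-length bookkeeping no longer happens inside register_new_token (the is_update_len_logit_processor flag is gone); the sampler instead syncs the logit processor once per sequence group, right after update_processed_tokens_num, using the same min_processed_tokens value. A toy sketch of that ordering (stand-in types and assumed lengths, not the real Sampler/SequenceGroup/LogitProcessor classes):

// Toy sketch of the new single sync point per sequence group.
#include <cstddef>
#include <iostream>

struct LogitProcessorSketch {
    std::size_t generated_len = 0;
    void update_generated_len(std::size_t len) { generated_len = len; }
};

struct SequenceGroupSketch {
    std::size_t prompt_len = 0;
    std::size_t processed_tokens = 0;
    std::size_t get_prompt_len() const { return prompt_len; }
    void update_processed_tokens_num(std::size_t n) { processed_tokens = n; }
};

int main() {
    SequenceGroupSketch group{/*prompt_len=*/32};   // assumed example prompt length
    LogitProcessorSketch logit_processor;
    std::size_t min_generated_len = 5;              // assumed: shortest generated length across sequences

    // Same ordering as the patched epilogue of Sampler::sample.
    auto min_processed_tokens = group.get_prompt_len() + min_generated_len - 1;
    group.update_processed_tokens_num(min_processed_tokens);
    logit_processor.update_generated_len(min_processed_tokens);

    std::cout << "processed: " << group.processed_tokens
              << ", logit processor len: " << logit_processor.generated_len << "\n";
    return 0;
}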

src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp (+2)

@@ -308,6 +308,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
             request->pause_generation(true);
         } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt) {
             request->pause_generation(true);
+        } else if (request->get_num_processed_tokens() - request->get_prompt_len() + 1 >= sampling_params.max_new_tokens - 1) {
+            request->pause_generation(true);
         }
         to_generate |= request->can_generate_tokens();
     }
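
The added branch also pauses the draft request once it has produced close to the main request's max_new_tokens budget, independently of the assistant-token counters checked above. A standalone check of the condition with assumed example values:

// Standalone check of the added pause condition; all values are assumptions for illustration.
#include <cstddef>
#include <iostream>

int main() {
    std::size_t num_processed_tokens = 40;  // assumed: request->get_num_processed_tokens()
    std::size_t prompt_len = 32;            // assumed: request->get_prompt_len()
    std::size_t max_new_tokens = 10;        // assumed: sampling_params.max_new_tokens

    // 40 - 32 + 1 = 9 >= 10 - 1 = 9, so the draft request would be paused here.
    bool pause = num_processed_tokens - prompt_len + 1 >= max_new_tokens - 1;
    std::cout << (pause ? "pause_generation(true)" : "keep generating") << "\n";
    return 0;
}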

src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp (+2 -2)

@@ -181,8 +181,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
 
         auto draft_sampling_params = sampling_params[request_id];
         // set the parameters do not stop draft generation without stopping of the same request for main pipeline
-        draft_sampling_params.max_new_tokens = SIZE_MAX - 1;
-        draft_sampling_params.min_new_tokens = SIZE_MAX - 1;
+        draft_sampling_params.max_new_tokens = draft_sampling_params.max_new_tokens + 1;
+        draft_sampling_params.min_new_tokens = draft_sampling_params.min_new_tokens + 1;
         draft_sampling_params.ignore_eos = true;
         draft_generations.push_back(m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params));
         // decrease generation len to generate last token by main model
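
Instead of an effectively unbounded SIZE_MAX - 1 budget, the draft request now inherits the main request's own max_new_tokens / min_new_tokens plus one token. A sketch with a stand-in params struct (assumed field defaults; not the library's own config type):

// Stand-in params struct mirroring the fields touched by the diff.
#include <cstddef>
#include <iostream>

struct SamplingParamsSketch {
    std::size_t max_new_tokens = 20;
    std::size_t min_new_tokens = 0;
    bool ignore_eos = false;
};

int main() {
    SamplingParamsSketch main_params;
    main_params.max_new_tokens = 64;  // assumed example budget of the main request

    // Same adjustment as in SpeculativeDecodingImpl::generate for the draft request.
    SamplingParamsSketch draft_sampling_params = main_params;
    draft_sampling_params.max_new_tokens = draft_sampling_params.max_new_tokens + 1;
    draft_sampling_params.min_new_tokens = draft_sampling_params.min_new_tokens + 1;
    draft_sampling_params.ignore_eos = true;  // the draft must not stop on EOS by itself

    std::cout << "draft max_new_tokens: " << draft_sampling_params.max_new_tokens << "\n";
    return 0;
}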
