@@ -565,23 +565,12 @@ void register_new_token(const Token& sampled_token_id,
565
565
Sequence::Ptr running_sequence,
566
566
LogitProcessor& logit_processor,
567
567
bool is_extend_sequence,
568
- bool is_update_len_logit_processor,
569
568
bool is_validation_mode_enabled) {
570
569
logit_processor.register_new_generated_token (sampled_token_id.m_index );
571
- size_t generated_len = logit_processor.get_generated_len ();
572
570
if (is_extend_sequence) {
573
571
running_sequence->append_token (sampled_token_id.m_index , sampled_token_id.m_log_prob );
574
- } else {
575
- // just update the token log prob in case of successfully validated token
576
- OPENVINO_ASSERT (generated_len < running_sequence->get_generated_len ());
577
- running_sequence->update_generated_log_prob (generated_len, sampled_token_id.m_log_prob );
578
- }
579
- // increment seq len only for one sequence in sequence group to sync them
580
- if (is_update_len_logit_processor) {
581
- logit_processor.update_generated_len (++generated_len);
582
572
}
583
573
if (!is_validation_mode_enabled &&
584
- logit_processor.is_dynamic_speculative_decoding () &&
585
574
std::fabs (sampled_token_id.m_log_prob ) < logit_processor.get_assistant_confidence_threshold ()) {
586
575
auto sequence_group = running_sequence->get_sequence_group_ptr ();
587
576
sequence_group->pause_generation (true );
@@ -604,7 +593,7 @@ create_n_forked_sequences(SequenceGroup::Ptr sequence_group,
604
593
const auto forked_sequence = sequence_group->fork_sequence (sequence_to_fork);
605
594
const auto forked_seq_id = forked_sequence->get_id ();
606
595
forked_seq_ids.push_back (forked_seq_id);
607
- register_new_token (sampled_tokens[i], forked_sequence, logit_processor, true , false , false );
596
+ register_new_token (sampled_tokens[i], forked_sequence, logit_processor, true , false );
608
597
}
609
598
return forked_seq_ids;
610
599
}
@@ -616,6 +605,8 @@ stop_sample_tokens(Sequence::Ptr running_sequence,
616
605
size_t & max_removed_tokens_per_request) {
617
606
running_sequence->remove_last_tokens (token_idx);
618
607
max_removed_tokens_per_request = std::max (max_removed_tokens_per_request, token_idx);
608
+ running_sequence->set_status (SequenceStatus::FINISHED);
609
+ running_sequence->set_finish_reason (GenerationFinishReason::STOP);
619
610
}
620
611
621
612
void
@@ -742,13 +733,16 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
742
733
}
743
734
// flag to add sampled token to generated sequence or extend logit processors only
744
735
bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens,
745
- // flag to update generated length of sequence group in logit processor
746
- is_update_len_logit_processor = running_sequence_id == num_running_sequences - 1 ,
747
736
is_validation_passed = true ;
748
737
if (is_validation_mode_enabled && !is_generate_n_tokens) {
749
738
is_validation_passed = validate_candidate (running_sequences[running_sequence_id], token_offset, sampled_token_id, is_extend_sequence, max_removed_tokens_per_request);
739
+ // update log prob just while validation process
740
+ if (!is_extend_sequence) {
741
+ OPENVINO_ASSERT (generated_and_verified_len < running_sequences[running_sequence_id]->get_generated_len ());
742
+ running_sequence->update_generated_log_prob (generated_and_verified_len, sampled_token_id.m_log_prob );
743
+ }
750
744
}
751
- register_new_token (sampled_token_id, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_update_len_logit_processor, is_validation_mode_enabled);
745
+ register_new_token (sampled_token_id, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_validation_mode_enabled);
752
746
// to exit from sampling in case of failed token validation
753
747
if (!is_validation_passed) {
754
748
break ;
@@ -794,6 +788,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
794
788
align_all_sequence_len (sequence_group, min_generated_len, logit_processor);
795
789
auto min_processed_tokens = sequence_group->get_prompt_len () + min_generated_len - 1 ;
796
790
sequence_group->update_processed_tokens_num (min_processed_tokens);
791
+ logit_processor.update_generated_len (min_processed_tokens);
797
792
}
798
793
799
794
// accumulate a number of processed tokens
0 commit comments