Remove tokens after EOS for draft model for speculative decoding #1951
@@ -337,5 +337,16 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
to_generate |= request->can_generate_tokens();
}
}

for (auto& request : m_requests) {
I thought we initially ignored EOS tokens for draft models. Why are they removed here? This should not affect the results of the main model, should it?
It was decided not to add the part after EOS for the draft model, according to this ticket: https://jira.devtools.intel.com/browse/CVS-164477. It does affect the results of the main model. What I saw is that if we have a stop token, the generation result can contain it and some tokens after it. With these changes, nothing follows the stop token.
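To illustrate the behavior described above, here is a minimal sketch of dropping everything after the first stop token. The function name `truncate_after_stop_token` and its signature are hypothetical and are not taken from this PR; it also assumes stop token IDs are held in a `std::set<int64_t>`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical helper (not the PR's code): keep tokens up to and including
// the first stop token and drop everything generated after it.
// When ignore_eos is true, the sequence is returned unchanged.
std::vector<int64_t> truncate_after_stop_token(std::vector<int64_t> tokens,
                                               const std::set<int64_t>& stop_token_ids,
                                               bool ignore_eos) {
    if (ignore_eos) {
        return tokens;
    }
    // Locate the first token that is one of the stop token IDs.
    auto it = std::find_if(tokens.begin(), tokens.end(), [&](int64_t token) {
        return stop_token_ids.count(token) > 0;
    });
    if (it != tokens.end()) {
        // Erase the tail that follows the stop token; the stop token itself stays.
        tokens.erase(it + 1, tokens.end());
    }
    return tokens;
}
```

For example, with stop token `2`, the draft output `{5, 7, 2, 9, 4}` becomes `{5, 7, 2}`, while the same call with `ignore_eos` set to `true` leaves the sequence as it was.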
We discussed offline how best to implement this.
@sbalandi
LGTM. Could you please fix one comment?
@@ -851,6 +853,12 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr
// to exit from sampling in case of failed token validation
if (!is_validation_passed) {
break;
} else {
auto sampling_params = sequence_group->get_sampling_parameters();
if (is_stop_token_id_hit(sampled_token.m_index, sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
Minor: it looks like is_stop_token_id_hit is equivalent to a simple find :D
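A rough sketch of what this suggestion could look like, assuming `stop_token_ids` is a `std::set<int64_t>` and that the existing helper is a plain membership check. Both function bodies below are illustrative guesses with hypothetical names, not the repository's code.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Assumed shape of the existing helper: scan the set and compare each ID.
bool is_stop_token_id_hit_loop(int64_t token, const std::set<int64_t>& stop_token_ids) {
    for (int64_t stop_token_id : stop_token_ids) {
        if (token == stop_token_id) {
            return true;
        }
    }
    return false;
}

// Equivalent membership check written as a simple find on the set.
bool is_stop_token_id_hit_find(int64_t token, const std::set<int64_t>& stop_token_ids) {
    return stop_token_ids.find(token) != stop_token_ids.end();
}
```

If the container really is a `std::set`, the member `find` is logarithmic in the number of stop tokens, whereas the loop is linear. With only a handful of stop tokens the difference is mostly one of readability.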
CVS-164477