@@ -261,7 +261,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update
     const size_t num_processed_tokens = request->get_num_processed_tokens(),
                  prompt_len = request->get_prompt_len(),
                  updated_context_len = min_candidate_len + prompt_len,
-                 max_new_tokens = request->get_sampling_parameters().max_new_tokens;
+                 max_new_tokens = request->get_sampling_parameters().get_max_new_tokens(request->get_prompt_len());
     size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0;
     if (generated_len > 0 && result.removed_tokens_cnt > 0) {
         request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1);
@@ -324,13 +324,13 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m
         // generate only one token in case of non speculative decoding
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() >= request->get_prompt_len() &&
-               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.max_new_tokens - 1) {
+               (request->get_num_processed_tokens() - request->get_prompt_len() + 1) >= sampling_params.get_max_new_tokens(request->get_prompt_len()) - 1) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == 0 && sampling_params.num_return_sequences > 1) {
         request->pause_generation(true);
     } else if (sampling_params.num_assistant_tokens <= generated_tokens_cnt && sampling_params.assistant_confidence_threshold == 0.f) {
         request->pause_generation(true);
-    } else if (sampling_params.max_new_tokens == 0) {
+    } else if (sampling_params.get_max_new_tokens(request->get_prompt_len()) == 0) {
         request->pause_generation(true);
     } else if (request->get_num_processed_tokens() == request->get_prompt_len()) {
         request->pause_generation(true);
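
Note on the new call: both hunks replace direct reads of `max_new_tokens` with
`get_max_new_tokens(prompt_len)`, so the token budget can also come from a
`max_length`-style limit that counts the prompt. A minimal sketch of that
resolution as a standalone helper; the SIZE_MAX sentinel and the max_length
semantics are assumptions, not a copy of the shipped GenerationConfig code:

    #include <cstddef>
    #include <limits>

    // Resolve the effective new-token budget: an explicitly set
    // max_new_tokens wins; otherwise derive the budget from max_length,
    // which (by assumption here) counts prompt plus generated tokens.
    size_t resolve_max_new_tokens(size_t max_new_tokens, size_t max_length,
                                  size_t prompt_length) {
        if (max_new_tokens != std::numeric_limits<size_t>::max())
            return max_new_tokens;  // explicit budget takes precedence
        return max_length > prompt_length ? max_length - prompt_length : 0;
    }

Passing the prompt length at each call site is what lets the pause conditions
above honor a length limit for requests that never set max_new_tokens directly.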