Commit cfd220e
[Sampler] Fix stop strings offset for speculative decoding (#1719)
Fixes behavior broken in #1676: with speculative decoding more than one token can be generated per step, so stop-string matching must be anchored at the step boundary and match `step_substring` (see the sketch below): ![image](https://github.com/user-attachments/assets/99fb9e64-37c9-4704-a90e-82e8a74baaaa)
1 parent 81cd23e commit cfd220e
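To make the offset arithmetic concrete, here is a minimal Python sketch of the corrected window computation. It is not the library code: `match_window_start` and its parameters are illustrative stand-ins for the C++ `match_stop_string` change below. The idea is to first step back over the tokens produced in the current step, then widen the window by the longest stop-string length.

```python
# Illustrative sketch of the fixed offset logic; mirrors the C++ change below.
def match_window_start(generated_tokens: list[int],
                       max_stop_string_tokens: int,
                       draft_generated_tokens: int = 0):
    """Index where stop-string matching should begin, or None if the
    window would start before the sequence does."""
    if len(generated_tokens) < max_stop_string_tokens:
        return None
    # Step back over the tokens produced in the current step first...
    offset = len(generated_tokens) - draft_generated_tokens
    if offset < max_stop_string_tokens:
        return None
    # ...then widen the window by the longest stop-string length.
    return offset - max_stop_string_tokens

tokens = list(range(8))
print(match_window_start(tokens, 3))     # 5 -> regular decoding, window hugs the tail
print(match_window_start(tokens, 3, 3))  # 2 -> 3 draft tokens this step shift the window left
```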

File tree: 2 files changed, +26 -15 lines changed
src/cpp/src/sampler.cpp (+10 -3)
```diff
@@ -107,10 +107,16 @@ struct MatchStopStringResult {
 MatchStopStringResult match_stop_string(Tokenizer& tokenizer,
                                         const TokenIds& generated_tokens,
                                         const std::pair<size_t, std::set<std::string>>& stop_strings,
-                                        bool is_include_to_output) {
+                                        bool is_include_to_output,
+                                        size_t draft_generated_tokens = 0) {
     MatchStopStringResult result;
     if (generated_tokens.size() >= stop_strings.first) {
-        size_t offset = generated_tokens.size() - stop_strings.first;
+        // draft_generated_tokens is to handle case with >= 1 generated tokens per step
+        size_t offset = generated_tokens.size() - draft_generated_tokens;
+        if (offset < stop_strings.first) {
+            return result;
+        }
+        offset -= stop_strings.first;
         TokenIds buffer(generated_tokens.begin() + offset, generated_tokens.end());
         std::string decoded_buffer = tokenizer.decode(buffer);
         for (const auto& stop_string : stop_strings.second) {
@@ -567,7 +573,8 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
 
     if (!sampling_params.stop_strings.empty()) {
         auto& stop_strings = m_stop_strings.at(sequence_group->get_request_id());
-        auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings, sampling_params.include_stop_str_in_output);
+        auto match_result = match_stop_string(m_tokenizer, running_sequence->get_generated_ids(), stop_strings,
+                                              sampling_params.include_stop_str_in_output, sequence_group->get_num_tokens_to_validate());
         if (match_result.is_matched) {
             running_sequence->remove_last_tokens(match_result.to_remove);
 
```
tests/python_tests/test_continuous_batching.py (+16 -12)
```diff
@@ -376,30 +376,32 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
     model_id : str = "facebook/opt-125m"
     opt_model, hf_tokenizer = get_hugging_face_models(model_id)
 
-    models_path : Path = tmp_path / "t_streaming" / model_id
+    models_path : Path = tmp_path / model_id
     convert_models(opt_model, hf_tokenizer, models_path)
 
     generation_config = GenerationConfig()
-    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
+    pipe, input, generation_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
 
+    it_cnt = 0
     def py_streamer(py_str: str):
+        nonlocal it_cnt
+        it_cnt += 1
         return False
 
-    try:
-        _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
-    except Exception:
-        assert True
+    _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
 
     del pipe
     rmtree(models_path)
 
+    assert it_cnt > 0
+
 @pytest.mark.parametrize("pipeline_type", ["continuous_batching", "speculative_decoding", "prompt_lookup_decoding", "llm_pipeline"])
 @pytest.mark.precommit
 def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type):
     model_id : str = "facebook/opt-125m"
     opt_model, hf_tokenizer = get_hugging_face_models(model_id)
 
-    models_path : Path = tmp_path / "t_streaming" / model_id
+    models_path : Path = tmp_path / model_id
     convert_models(opt_model, hf_tokenizer, models_path)
 
     generation_config = GenerationConfig()
@@ -408,13 +410,15 @@ def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type)
 
     pipe, input, generation_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
 
+    it_cnt = 0
     def py_streamer(py_str: str):
-        raise Exception("Streamer was called")
+        nonlocal it_cnt
+        it_cnt += 1
+        return False
 
-    try:
-        _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
-    except Exception:
-        assert False
+    _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
 
     del pipe
     rmtree(models_path)
+
+    assert it_cnt == 0
```
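For reference, a small self-contained sketch of the counting-streamer pattern the updated tests rely on (the loop below is a stand-in for `pipe.generate(..., streamer=py_streamer)` pushing decoded chunks; it assumes nothing about the real pipeline): the nested callback bumps a counter in the enclosing scope via `nonlocal`, so the test can assert how many times the streamer fired instead of routing control flow through exceptions.

```python
# Stand-alone sketch of the counting-streamer pattern from the updated tests.
def run_streaming_check():
    it_cnt = 0

    def py_streamer(py_str: str) -> bool:
        nonlocal it_cnt   # mutate the counter in the enclosing scope
        it_cnt += 1
        return False      # False = do not interrupt generation

    # Stand-in for pipe.generate(..., streamer=py_streamer) pushing chunks:
    for chunk in ["Hello", ",", " world"]:
        if py_streamer(chunk):
            break

    assert it_cnt > 0     # the streamer must have been invoked

run_streaming_check()
```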
