Commit 6f5d342

Remove paddings and add can_use_partial_preemption option for Scheduler

1 parent 5c439d2 commit 6f5d342

3 files changed: +191 -18

src/cpp/src/continuous_batching_pipeline.cpp (+8 -1)

@@ -112,7 +112,14 @@ class ContinuousBatchingPipeline::Impl {
             updated_config.num_kv_blocks = device_config.get_num_kv_blocks();
         }
 
-        m_scheduler = std::make_shared<Scheduler>(updated_config);
+        bool can_use_partial_preemption = true;
+        if (device_config.get_device().find("GPU") != std::string::npos && !updated_config.dynamic_split_fuse) {
+            // in case of executing a `vLLM-like` pipeline, it's better not to use partial eviction on the GPU,
+            // as it may lead to performance slowdown
+            can_use_partial_preemption = false;
+        }
+
+        m_scheduler = std::make_shared<Scheduler>(updated_config, can_use_partial_preemption);
         // and finally create model runner
         m_model_runner = std::make_shared<ModelRunner>(infer_request, updated_config);
         m_sampler = std::make_shared<Sampler>(m_tokenizer);
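The hunk above is the only place the new flag is derived: partial preemption stays enabled everywhere except a GPU device running the vLLM-like (non-dynamic_split_fuse) pipeline. A minimal standalone sketch of that policy; the helper name choose_partial_preemption is hypothetical and not part of the commit:

#include <string>

// Hypothetical helper mirroring the condition added above: disable partial
// preemption only for GPU devices when dynamic_split_fuse is off.
static bool choose_partial_preemption(const std::string& device, bool dynamic_split_fuse) {
    const bool gpu_vllm_like = device.find("GPU") != std::string::npos && !dynamic_split_fuse;
    return !gpu_vllm_like;
}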

src/cpp/src/scheduler.hpp (+20 -16)

@@ -8,11 +8,14 @@
 #include <vector>
 
 #include "openvino/genai/scheduler_config.hpp"
+#include "device_config.hpp"
 #include "block_manager.hpp"
 #include "sequence_group.hpp"
 
 namespace ov::genai {
 class Scheduler {
+    bool m_can_use_partial_preemption;
+
     SchedulerConfig m_config;
     BlockManager m_block_manager;

@@ -32,8 +35,11 @@ class Scheduler {
         float m_cache_usage = 0.0;
     };
 
-    explicit Scheduler(const SchedulerConfig & config = {}) :
-        m_config(config), m_block_manager(m_config.num_kv_blocks, m_config.enable_prefix_caching, m_config.block_size) { }
+    explicit Scheduler(const SchedulerConfig & config = {}, bool can_use_partial_preemption = true) :
+        m_can_use_partial_preemption(can_use_partial_preemption),
+        m_config(config),
+        m_block_manager(m_config.num_kv_blocks, m_config.enable_prefix_caching, m_config.block_size) {
+    }
 
     Output schedule(std::vector<SequenceGroup::Ptr>& sequence_groups) {
         Output scheduler_output;
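For reference, a minimal usage sketch of the extended constructor; the config values are borrowed from the tests below, and nothing beyond the two-argument call is mandated by this commit:

SchedulerConfig config;
config.num_kv_blocks = 6;
config.block_size = 4;
config.dynamic_split_fuse = false;
// The second argument defaults to true, so existing callers keep the old behavior.
Scheduler scheduler(config, /* can_use_partial_preemption = */ false);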
@@ -47,7 +53,6 @@ class Scheduler {
         } else {
             // vLLM case
             // schedule prompt phase using whole prompt's input_ids
-            // note, that we also apply padding, while need to be considered by model runner
 
             _schedule_prompt_phase_vllm(sequence_groups, scheduler_output);
 

@@ -105,7 +110,7 @@ class Scheduler {
         size_t preempted_tokens = 0;
         size_t num_blocks_occupied_by_sequence = m_block_manager.get_number_of_blocks_occupied_by_sequence(sequence_group);
 
-        if (num_blocks_occupied_by_sequence <= blocks_needed) {
+        if (num_blocks_occupied_by_sequence <= blocks_needed || !m_can_use_partial_preemption) {
             auto sequences = sequence_group->get_not_finished_sequences();
             for (size_t s = 0; s < sequences.size(); ++s) {
                 auto seq_id = sequences[s]->get_id();
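The changed condition above is the core behavioral switch: a group is now evicted wholesale either when it is small enough that partial freeing would empty it anyway, or when partial preemption is disabled for this pipeline. A sketch of that predicate as a free function, hypothetical and for illustration only:

#include <cstddef>

// Mirrors the if-condition above: evict the whole group when its footprint is
// within what must be freed, or when partial preemption is not allowed at all.
static bool should_fully_evict(size_t blocks_occupied, size_t blocks_needed,
                               bool can_use_partial_preemption) {
    return blocks_occupied <= blocks_needed || !can_use_partial_preemption;
}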
@@ -115,7 +120,7 @@ class Scheduler {
             sequence_group->set_waiting();
             return m_block_manager.num_free_blocks() > prev_blocks_count;
         }
-
+
         size_t logical_blocks_released;
         if (sequence_group->get_sampling_parameters().is_beam_search()) {
             logical_blocks_released = m_block_manager.free_partially_beam_search_group(sequence_group, blocks_needed);

@@ -126,7 +131,7 @@ class Scheduler {
 
         // calculate the number of preempted tokens
         auto tokens_in_last_block = processed_tokens % block_size;
-        if (tokens_in_last_block == 0) {
+        if (tokens_in_last_block == 0) {
             tokens_in_last_block = block_size;
         }
         preempted_tokens = tokens_in_last_block + std::max<size_t>((int)logical_blocks_released - 1, 0) * block_size;
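As a worked check of the formula above (the numbers are made up, not taken from the commit): with block_size = 4, 10 processed tokens and 2 logical blocks released, tokens_in_last_block is 10 % 4 = 2, so preempted_tokens is 2 + (2 - 1) * 4 = 6.

#include <algorithm>
#include <cassert>
#include <cstddef>

int main() {
    const size_t block_size = 4, processed_tokens = 10, logical_blocks_released = 2;
    size_t tokens_in_last_block = processed_tokens % block_size;   // 2
    if (tokens_in_last_block == 0)
        tokens_in_last_block = block_size;                         // a full last block counts whole
    const size_t preempted_tokens = tokens_in_last_block +
        std::max<size_t>((int)logical_blocks_released - 1, 0) * block_size;
    assert(preempted_tokens == 6);
    return 0;
}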
@@ -166,7 +171,7 @@ class Scheduler {
         while (!m_block_manager.can_append_slots(sequence_group)) {
             // let's run a sequence for eviction
             size_t evicted_sequence_group_id = _get_low_priority_sequence_group_id(sequence_groups);
-
+
             if (evicted_sequence_group_id <= sequence_group_id) {
                 // we have a cycle when current group need to evict itself to be in a running state
                 break;

@@ -265,7 +270,7 @@ class Scheduler {
                 sequence_group->clear_scheduled_tokens();
                 continue;
             }
-
+
             // allocate new slots
             std::map<size_t, std::list<size_t>> copy_blocks_map = m_block_manager.append_slots(sequence_group);
 

@@ -311,19 +316,19 @@ class Scheduler {
         // TODO: it currently does not handle beam search, where beam width should contribute to total number of "num running sequences"
         size_t num_running_sequence_groups = _num_running_sequence_groups(sequence_groups);
 
-        for (size_t sequence_group_id = 0, num_scheduled_tokens = 0, max_sequence_len = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
+        for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
             SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
-            if (!sequence_group->can_generate_tokens() && !sequence_group->is_waiting()) {
+            const bool recompute_evicted_sequences = sequence_group->get_num_processed_tokens() == 0 && !m_can_use_partial_preemption;
+            if ((!sequence_group->can_generate_tokens() || recompute_evicted_sequences) && !sequence_group->is_waiting()) {
                 size_t num_running_seqs = sequence_group->num_running_seqs();
                 // prompt phases can have a single running sequence
                 OPENVINO_ASSERT(num_running_seqs == 1);
                 // here we also assume that sequence must be scheduler in a single shot and has no already generated context
                 if (!m_config.enable_prefix_caching)
                     OPENVINO_ASSERT(sequence_group->get_context_len() == 0);
 
-                int64_t num_available_tokens_in_megabatch = m_config.max_num_batched_tokens - scheduler_output.m_total_num_scheduled_tokens;
+                size_t num_available_tokens_in_megabatch = m_config.max_num_batched_tokens - scheduler_output.m_total_num_scheduled_tokens;
                 size_t sequence_len = sequence_group->get_num_available_tokens_for_batching();
-                max_sequence_len = std::max(max_sequence_len, sequence_len);
 
                 // TODO: better handling
                 // e.g. return status that sequence is ignored and cannot be processed by current scheduling algorigthm

@@ -334,7 +339,7 @@ class Scheduler {
                     break;
 
                 // apply max num batched tokens limitation
-                if (num_available_tokens_in_megabatch < static_cast<int64_t>(max_sequence_len))
+                if (num_available_tokens_in_megabatch < sequence_len)
                     break;
 
                 // apply KV cache limitations

@@ -357,21 +362,20 @@ class Scheduler {
                 {
                     scheduler_output.m_scheduled_sequence_groups_ids.push_back(sequence_group_id);
                     scheduler_output.m_block_tables[seq_id] = m_block_manager.get_block_table(seq_id);
-                    scheduler_output.m_total_num_scheduled_tokens = max_sequence_len * scheduler_output.m_scheduled_sequence_groups_ids.size();
+                    scheduler_output.m_total_num_scheduled_tokens += sequence_len;
                 }
 
                 // update "is_prompt" flag
                 scheduler_output.is_prompt = true;
             }
 
-            num_scheduled_tokens += sequence_len;
             num_running_sequence_groups += 1;
         }
     }
 
     void _clear_waiting_sequences(const std::vector<SequenceGroup::Ptr>& sequence_groups) {
-        for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
+        for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
             sequence_groups[sequence_group_id]->clear_waiting_sequences();
         }
     }
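The last three hunks are where the "remove paddings" half of the commit message lands: the old code billed every scheduled prompt at max_sequence_len (an implicitly padded batch, max length times group count), while the new code adds each group's own sequence_len to m_total_num_scheduled_tokens. A self-contained illustration with hypothetical prompt lengths:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    // Two prompt-phase groups with made-up lengths.
    const std::vector<size_t> prompt_lens = {5, 12};
    // Old accounting: every group billed at the longest prompt.
    const size_t longest = *std::max_element(prompt_lens.begin(), prompt_lens.end());
    const size_t padded_total = longest * prompt_lens.size();   // 24 tokens charged
    // New accounting: each group billed for exactly its own tokens.
    size_t exact_total = 0;
    for (size_t len : prompt_lens)
        exact_total += len;                                     // 17 tokens charged
    assert(padded_total == 24 && exact_total == 17);
    return 0;
}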

tests/cpp/scheduler.cpp (+163 -1)

@@ -675,4 +675,166 @@ TEST(TestScheduler, prefix_caching_with_max_new_tokens_equal_1) {
     }
 }
 
-}
+}
+
+TEST(TestScheduler, test_partially_preempted_prompt_not_allowed) {
+    SchedulerConfig scheduler_config;
+    scheduler_config.max_num_batched_tokens = 32;
+    scheduler_config.num_kv_blocks = 6;
+    scheduler_config.block_size = 4;
+    scheduler_config.dynamic_split_fuse = false;
+    scheduler_config.max_num_seqs = 5;
+
+    std::vector<uint64_t> tokens = {0,1,2,3,4,5,6,7,8,9,10,11};
+    SequenceGroup::Ptr sequence_group1 = std::make_shared<SequenceGroup>(0, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx0 = (*sequence_group1)[0]->get_id();
+    SequenceGroup::Ptr sequence_group2 = std::make_shared<SequenceGroup>(1, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx1 = (*sequence_group2)[0]->get_id();
+    std::vector<SequenceGroup::Ptr> requests = {sequence_group1, sequence_group2};
+
+    // schedule 2 sequence groups that together use all of the available 2*3 KV blocks
+    const bool can_use_partial_preemption = false;
+    Scheduler scheduler = Scheduler(scheduler_config, can_use_partial_preemption);
+    auto out1 = scheduler.schedule(requests);
+
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // sequence_group2 should be fully preempted
+    auto out2 = scheduler.schedule(requests);
+
+    // check that sequence_group1 has one more allocated block
+    auto block_table1 = scheduler.get_block_table(*(*sequence_group1)[0]);
+    ASSERT_EQ(block_table1.size(), 4);
+    ASSERT_EQ(block_table1[0]->get_index(), 0);
+    ASSERT_EQ(block_table1[1]->get_index(), 1);
+    ASSERT_EQ(block_table1[2]->get_index(), 2);
+    ASSERT_EQ(block_table1[3]->get_index(), 3);
+    ASSERT_EQ(out2.m_block_tables[idx0].size(), 4);
+    ASSERT_EQ(out2.m_block_tables[idx0][0]->get_index(), 0);
+    ASSERT_EQ(out2.m_block_tables[idx0][1]->get_index(), 1);
+    ASSERT_EQ(out2.m_block_tables[idx0][2]->get_index(), 2);
+    ASSERT_EQ(out2.m_block_tables[idx0][3]->get_index(), 3);
+
+    std::vector<uint64_t> ref_ids = {0};
+    ASSERT_EQ(out2.m_scheduled_sequence_groups_ids, ref_ids);
+    ASSERT_EQ(out2.m_total_num_scheduled_tokens, 1);
+
+    // for the vLLM case sequence_group2 is fully preempted
+    EXPECT_FALSE(scheduler.has_block_table(idx1));
+
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // finish first sequence
+    requests[0]->get_running_sequences()[0]->set_status(SequenceStatus::FINISHED);
+    scheduler.free_sequence(idx0);
+    clear_finished_sequences(requests);
+
+    // sequence_group2 should be scheduled
+    auto out3 = scheduler.schedule(requests);
+
+    // prompt should be fully scheduled
+    ASSERT_EQ(out3.m_total_num_scheduled_tokens, 12);
+
+    ASSERT_EQ(out3.m_block_tables[idx1][0]->get_index(), 4);
+    ASSERT_EQ(out3.m_block_tables[idx1][1]->get_index(), 5);
+    ASSERT_EQ(out3.m_block_tables[idx1][2]->get_index(), 0);
+
+    auto block_table2 = scheduler.get_block_table(*(*sequence_group2)[0]);
+    ASSERT_EQ(block_table2.size(), 3);
+    ASSERT_EQ(block_table2[0]->get_index(), 4);
+    ASSERT_EQ(block_table2[1]->get_index(), 5);
+    ASSERT_EQ(block_table2[2]->get_index(), 0);
+
+    EXPECT_FALSE(scheduler.has_block_table(idx0));
+}
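Back-of-envelope for the setup of the test above (the constants mirror the test; the static_assert itself is only illustrative): a 12-token prompt at block_size = 4 occupies 3 KV blocks, so the two groups exactly exhaust num_kv_blocks = 6, and the second group must be evicted in full as soon as the first one needs a fourth block.

#include <cstddef>

constexpr size_t block_size = 4, num_kv_blocks = 6, prompt_len = 12;
constexpr size_t blocks_per_group = (prompt_len + block_size - 1) / block_size;  // ceil(12/4) = 3
static_assert(2 * blocks_per_group == num_kv_blocks,
              "two groups consume the whole KV cache");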
+
+TEST(TestScheduler, test_partially_preempted_prompt_not_allowed2) {
+    SchedulerConfig scheduler_config;
+    scheduler_config.max_num_batched_tokens = 32;
+    scheduler_config.num_kv_blocks = 6;
+    scheduler_config.block_size = 4;
+    scheduler_config.dynamic_split_fuse = false;
+    scheduler_config.max_num_seqs = 5;
+
+    std::vector<uint64_t> tokens = {0,1,2,3,4,5,6,7,8,9};
+    SequenceGroup::Ptr sequence_group1 = std::make_shared<SequenceGroup>(0, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx0 = (*sequence_group1)[0]->get_id();
+    SequenceGroup::Ptr sequence_group2 = std::make_shared<SequenceGroup>(1, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx1 = (*sequence_group2)[0]->get_id();
+    std::vector<SequenceGroup::Ptr> requests = {sequence_group1, sequence_group2};
+
+    // schedule 2 sequence groups that together use all of the available 2*3 KV blocks
+    const bool can_use_partial_preemption = false;
+    Scheduler scheduler = Scheduler(scheduler_config, can_use_partial_preemption);
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // sequence_group2 should be fully preempted
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    auto out2 = scheduler.schedule(requests);
+
+    // check that sequence_group1 has one more allocated block
+    auto block_table1 = scheduler.get_block_table(*(*sequence_group1)[0]);
+    ASSERT_EQ(block_table1.size(), 4);
+    ASSERT_EQ(block_table1[0]->get_index(), 0);
+    ASSERT_EQ(block_table1[1]->get_index(), 1);
+    ASSERT_EQ(block_table1[2]->get_index(), 2);
+    ASSERT_EQ(block_table1[3]->get_index(), 3);
+    ASSERT_EQ(out2.m_block_tables[idx0].size(), 4);
+    ASSERT_EQ(out2.m_block_tables[idx0][0]->get_index(), 0);
+    ASSERT_EQ(out2.m_block_tables[idx0][1]->get_index(), 1);
+    ASSERT_EQ(out2.m_block_tables[idx0][2]->get_index(), 2);
+    ASSERT_EQ(out2.m_block_tables[idx0][3]->get_index(), 3);
+
+    std::vector<uint64_t> ref_ids = {0};
+    ASSERT_EQ(out2.m_scheduled_sequence_groups_ids, ref_ids);
+    ASSERT_EQ(out2.m_total_num_scheduled_tokens, 1);
+
+    // for the vLLM case sequence_group2 is fully preempted
+    EXPECT_FALSE(scheduler.has_block_table(idx1));
+
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // finish first sequence
+    requests[0]->get_running_sequences()[0]->set_status(SequenceStatus::FINISHED);
+    scheduler.free_sequence(idx0);
+    clear_finished_sequences(requests);
+
+    // sequence_group2 should be scheduled
+    auto out3 = scheduler.schedule(requests);
+
+    // prompt should be fully scheduled + generated tokens concatenated to prompt (10 + 2)
+    ASSERT_EQ(out3.m_total_num_scheduled_tokens, 12);
+
+    ASSERT_EQ(out3.m_block_tables[idx1][0]->get_index(), 4);
+    ASSERT_EQ(out3.m_block_tables[idx1][1]->get_index(), 5);
+    ASSERT_EQ(out3.m_block_tables[idx1][2]->get_index(), 0);
+
+    auto block_table2 = scheduler.get_block_table(*(*sequence_group2)[0]);
+    ASSERT_EQ(block_table2.size(), 3);
+    ASSERT_EQ(block_table2[0]->get_index(), 4);
+    ASSERT_EQ(block_table2[1]->get_index(), 5);
+    ASSERT_EQ(block_table2[2]->get_index(), 0);
+
+    EXPECT_FALSE(scheduler.has_block_table(idx0));
+}
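The 12-token assertion near the end of this test follows from full preemption: group 2 ran its 10-token prompt and then generated for two iterations before being evicted, so once rescheduled it must recompute the prompt plus the already-generated context in one shot. A sanity check with the test's constants, for illustration only:

#include <cstddef>

constexpr size_t prompt_len = 10, generated_before_eviction = 2;
static_assert(prompt_len + generated_before_eviction == 12,
              "matches out3.m_total_num_scheduled_tokens");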
