@@ -675,4 +675,166 @@ TEST(TestScheduler, prefix_caching_with_max_new_tokens_equal_1) {
 }
 }
 
-}
+}
+
+TEST(TestScheduler, test_partially_preempted_prompt_not_allowed) {
+    SchedulerConfig scheduler_config;
+    scheduler_config.max_num_batched_tokens = 32;
+    scheduler_config.num_kv_blocks = 6;
+    scheduler_config.block_size = 4;
+    scheduler_config.dynamic_split_fuse = false;
+    scheduler_config.max_num_seqs = 5;
+
+    std::vector<uint64_t> tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+    SequenceGroup::Ptr sequence_group1 = std::make_shared<SequenceGroup>(0, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx0 = (*sequence_group1)[0]->get_id();
+    SequenceGroup::Ptr sequence_group2 = std::make_shared<SequenceGroup>(1, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx1 = (*sequence_group2)[0]->get_id();
+    std::vector<SequenceGroup::Ptr> requests = {sequence_group1, sequence_group2};
+
+
+    // schedule 2 sequence groups; together they use all of the available 2 * 3 KV blocks
+    const bool can_use_partial_preemption = false;
+    Scheduler scheduler = Scheduler(scheduler_config, can_use_partial_preemption);
+    auto out1 = scheduler.schedule(requests);
+
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // sequence_group2 should be fully preempted
+    auto out2 = scheduler.schedule(requests);
+
+    // check that sequence_group1 has one more allocated block
+    auto block_table1 = scheduler.get_block_table(*(*sequence_group1)[0]);
+    ASSERT_EQ(block_table1.size(), 4);
+    ASSERT_EQ(block_table1[0]->get_index(), 0);
+    ASSERT_EQ(block_table1[1]->get_index(), 1);
+    ASSERT_EQ(block_table1[2]->get_index(), 2);
+    ASSERT_EQ(block_table1[3]->get_index(), 3);
+    ASSERT_EQ(out2.m_block_tables[idx0].size(), 4);
+    ASSERT_EQ(out2.m_block_tables[idx0][0]->get_index(), 0);
+    ASSERT_EQ(out2.m_block_tables[idx0][1]->get_index(), 1);
+    ASSERT_EQ(out2.m_block_tables[idx0][2]->get_index(), 2);
+    ASSERT_EQ(out2.m_block_tables[idx0][3]->get_index(), 3);
+
+    std::vector<uint64_t> ref_ids = {0};
+    ASSERT_EQ(out2.m_scheduled_sequence_groups_ids, ref_ids);
+    ASSERT_EQ(out2.m_total_num_scheduled_tokens, 1);
+
+    // as in the vLLM case, sequence_group2 is fully preempted
+    EXPECT_FALSE(scheduler.has_block_table(idx1));
+
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // finish the first sequence
+    requests[0]->get_running_sequences()[0]->set_status(SequenceStatus::FINISHED);
+    scheduler.free_sequence(idx0);
+    clear_finished_sequences(requests);
+
+    // sequence_group2 should be scheduled
+    auto out3 = scheduler.schedule(requests);
+
+    // prompt should be fully scheduled
+    ASSERT_EQ(out3.m_total_num_scheduled_tokens, 12);
+
+    ASSERT_EQ(out3.m_block_tables[idx1][0]->get_index(), 4);
+    ASSERT_EQ(out3.m_block_tables[idx1][1]->get_index(), 5);
+    ASSERT_EQ(out3.m_block_tables[idx1][2]->get_index(), 0);
+
+    auto block_table2 = scheduler.get_block_table(*(*sequence_group2)[0]);
+    ASSERT_EQ(block_table2.size(), 3);
+    ASSERT_EQ(block_table2[0]->get_index(), 4);
+    ASSERT_EQ(block_table2[1]->get_index(), 5);
+    ASSERT_EQ(block_table2[2]->get_index(), 0);
+
+    EXPECT_FALSE(scheduler.has_block_table(idx0));
+}
+
+TEST(TestScheduler, test_partially_preempted_prompt_not_allowed2) {
+    SchedulerConfig scheduler_config;
+    scheduler_config.max_num_batched_tokens = 32;
+    scheduler_config.num_kv_blocks = 6;
+    scheduler_config.block_size = 4;
+    scheduler_config.dynamic_split_fuse = false;
+    scheduler_config.max_num_seqs = 5;
+
+    std::vector<uint64_t> tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+    SequenceGroup::Ptr sequence_group1 = std::make_shared<SequenceGroup>(0, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx0 = (*sequence_group1)[0]->get_id();
+    SequenceGroup::Ptr sequence_group2 = std::make_shared<SequenceGroup>(1, ov::Tensor(ov::element::i64, {tokens.size()}, tokens.data()),
+                                                                         ov::genai::greedy(), scheduler_config.block_size, scheduler_config.enable_prefix_caching);
+    auto idx1 = (*sequence_group2)[0]->get_id();
+    std::vector<SequenceGroup::Ptr> requests = {sequence_group1, sequence_group2};
+
+    // schedule 2 sequence groups; together they use all of the available 2 * 3 KV blocks
+    const bool can_use_partial_preemption = false;
+    Scheduler scheduler = Scheduler(scheduler_config, can_use_partial_preemption);
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // sequence_group2 should be fully preempted
+    scheduler.schedule(requests);
+    for (auto req : requests)
+        req->finish_iteration();
+
+    auto out2 = scheduler.schedule(requests);
+
+    // check that sequence_group1 has one more allocated block
+    auto block_table1 = scheduler.get_block_table(*(*sequence_group1)[0]);
+    ASSERT_EQ(block_table1.size(), 4);
+    ASSERT_EQ(block_table1[0]->get_index(), 0);
+    ASSERT_EQ(block_table1[1]->get_index(), 1);
+    ASSERT_EQ(block_table1[2]->get_index(), 2);
+    ASSERT_EQ(block_table1[3]->get_index(), 3);
+    ASSERT_EQ(out2.m_block_tables[idx0].size(), 4);
+    ASSERT_EQ(out2.m_block_tables[idx0][0]->get_index(), 0);
+    ASSERT_EQ(out2.m_block_tables[idx0][1]->get_index(), 1);
+    ASSERT_EQ(out2.m_block_tables[idx0][2]->get_index(), 2);
+    ASSERT_EQ(out2.m_block_tables[idx0][3]->get_index(), 3);
+
+    std::vector<uint64_t> ref_ids = {0};
+    ASSERT_EQ(out2.m_scheduled_sequence_groups_ids, ref_ids);
+    ASSERT_EQ(out2.m_total_num_scheduled_tokens, 1);
+
+    // as in the vLLM case, sequence_group2 is fully preempted
+    EXPECT_FALSE(scheduler.has_block_table(idx1));
+
+    for (auto req : requests)
+        req->finish_iteration();
+
+    // finish the first sequence
+    requests[0]->get_running_sequences()[0]->set_status(SequenceStatus::FINISHED);
+    scheduler.free_sequence(idx0);
+    clear_finished_sequences(requests);
+
+    // sequence_group2 should be scheduled
+    auto out3 = scheduler.schedule(requests);
+
+    // the prompt plus the previously generated tokens should be fully scheduled (10 + 2)
+    ASSERT_EQ(out3.m_total_num_scheduled_tokens, 12);
+
+    ASSERT_EQ(out3.m_block_tables[idx1][0]->get_index(), 4);
+    ASSERT_EQ(out3.m_block_tables[idx1][1]->get_index(), 5);
+    ASSERT_EQ(out3.m_block_tables[idx1][2]->get_index(), 0);
+
+    auto block_table2 = scheduler.get_block_table(*(*sequence_group2)[0]);
+    ASSERT_EQ(block_table2.size(), 3);
+    ASSERT_EQ(block_table2[0]->get_index(), 4);
+    ASSERT_EQ(block_table2[1]->get_index(), 5);
+    ASSERT_EQ(block_table2[2]->get_index(), 0);
+
+    EXPECT_FALSE(scheduler.has_block_table(idx0));
+}