@@ -778,6 +778,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
778
778
del model_with_cache
779
779
gc .collect ()
780
780
781
+ def test_beam_search (self ):
782
+ model_id = MODEL_NAMES ["llama" ]
783
+ ov_model_stateful = OVModelForCausalLM .from_pretrained (model_id , export = True , use_cache = True , stateful = True )
784
+ ov_model_stateless = OVModelForCausalLM .from_pretrained (model_id , export = True , use_cache = True , stateful = False )
785
+ transformers_model = AutoModelForCausalLM .from_pretrained (model_id )
786
+
787
+ tokenizer = AutoTokenizer .from_pretrained (model_id )
788
+ tokenizer .pad_token = tokenizer .eos_token
789
+ tokens = tokenizer (["Today is a nice day and I am longer" , "This is me" ], return_tensors = "pt" , padding = True )
790
+ ov_model_stateful .generation_config .eos_token_id = None
791
+ ov_model_stateless .generation_config .eos_token_id = None
792
+ transformers_model .generation_config .eos_token_id = None
793
+ ov_model_stateful .config .eos_token_id = None
794
+ ov_model_stateless .config .eos_token_id = None
795
+ transformers_model .config .eos_token_id = None
796
+
797
+ # beam search
798
+ gen_config = GenerationConfig (
799
+ max_new_tokens = 10 ,
800
+ min_new_tokens = 10 ,
801
+ num_beams = 4 ,
802
+ do_sample = False ,
803
+ eos_token_id = None ,
804
+ )
805
+
806
+ transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
807
+ ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
808
+ self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
809
+ ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
810
+ self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
811
+ # beam sample
812
+ gen_config = GenerationConfig (
813
+ max_new_tokens = 10 ,
814
+ min_new_tokens = 10 ,
815
+ num_beams = 4 ,
816
+ do_sample = True ,
817
+ eos_token_id = None ,
818
+ top_k = 1 ,
819
+ )
820
+
821
+ transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
822
+ ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
823
+ self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
824
+ ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
825
+ self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
826
+
827
+ # group beam search
828
+ gen_config = GenerationConfig (
829
+ max_new_tokens = 10 ,
830
+ min_new_tokens = 10 ,
831
+ num_beams = 4 ,
832
+ do_sample = False ,
833
+ eos_token_id = None ,
834
+ num_beam_groups = 2 ,
835
+ diversity_penalty = 0.0000001 ,
836
+ )
837
+
838
+ transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
839
+ ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
840
+ self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
841
+ ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
842
+ self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
843
+
844
+ # constrained beam search
845
+ force_word = "cat"
846
+ force_words_ids = [tokenizer ([force_word ], add_special_tokens = False ).input_ids ]
847
+ gen_config = GenerationConfig (
848
+ max_new_tokens = 10 ,
849
+ min_new_tokens = 10 ,
850
+ num_beams = 4 ,
851
+ do_sample = False ,
852
+ eos_token_id = None ,
853
+ force_words_ids = force_words_ids ,
854
+ )
855
+
856
+ transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
857
+ ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
858
+ self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
859
+ ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
860
+ self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
861
+
781
862
782
863
class OVModelForMaskedLMIntegrationTest (unittest .TestCase ):
783
864
SUPPORTED_ARCHITECTURES = (
0 commit comments