@@ -757,6 +757,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
757
757
del model_with_cache
758
758
gc .collect ()
759
759
760
def test_beam_search(self):
    """Compare OpenVINO beam-search generation against the transformers reference.

    A stateful and a stateless ``OVModelForCausalLM`` export of the same
    checkpoint are run through four beam-based decoding strategies (beam
    search, beam sample, group beam search and constrained beam search).
    For each strategy the generated token ids must be identical to those
    produced by the ``AutoModelForCausalLM`` reference model.
    """
    model_id = MODEL_NAMES["llama"]
    ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
    ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_id)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
        # Two prompts of different length so that padding is exercised.
        tokens = tokenizer(
            ["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True
        )

        # Clear EOS on both the generation config and the model config so that
        # no model stops early; together with min_new_tokens == max_new_tokens
        # below, every run is asked for exactly the same number of new tokens.
        for model in (ov_model_stateful, ov_model_stateless, transformers_model):
            model.generation_config.eos_token_id = None
            model.config.eos_token_id = None

        force_word = "cat"
        force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]

        # Settings shared by every decoding strategy under test.
        common_kwargs = {
            "max_new_tokens": 10,
            "min_new_tokens": 10,
            "num_beams": 4,
            "eos_token_id": None,
        }
        gen_configs = {
            "beam search": GenerationConfig(do_sample=False, **common_kwargs),
            # top_k=1 restricts sampling to the single best token per beam.
            "beam sample": GenerationConfig(do_sample=True, top_k=1, **common_kwargs),
            "group beam search": GenerationConfig(
                do_sample=False, num_beam_groups=2, diversity_penalty=0.0000001, **common_kwargs
            ),
            "constrained beam search": GenerationConfig(
                do_sample=False, force_words_ids=force_words_ids, **common_kwargs
            ),
        }
        ov_models = {"stateful": ov_model_stateful, "stateless": ov_model_stateless}

        for strategy, gen_config in gen_configs.items():
            # Re-seed before every generate() call so that each model starts
            # from the same torch RNG state (relevant when do_sample=True).
            torch.manual_seed(0)
            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)

            for variant, ov_model in ov_models.items():
                # subTest keeps going after a failure, so every
                # strategy/variant combination is reported on its own.
                with self.subTest(strategy=strategy, variant=variant):
                    torch.manual_seed(0)
                    ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)
                    # Token ids are integers: require exact equality rather
                    # than a floating-point tolerance comparison.
                    self.assertTrue(
                        torch.equal(ov_outputs, transformers_outputs),
                        f"{variant} OpenVINO model output differs from transformers for {strategy}",
                    )
    finally:
        # Release the models explicitly, as the neighbouring tests do.
        del ov_model_stateful, ov_model_stateless, transformers_model
        gc.collect()
840
+
760
841
761
842
class OVModelForMaskedLMIntegrationTest (unittest .TestCase ):
762
843
SUPPORTED_ARCHITECTURES = (
0 commit comments