@@ -792,41 +792,16 @@ def test_beam_search(self, model_arch):
         # Qwen tokenizer does not support padding, and the chatglm testing model produces nan values that are incompatible with beam search
         if model_arch in ["qwen", "chatglm"]:
             return
-
-        ov_model_stateful = OVModelForCausalLM.from_pretrained(
-            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
-        )
-        ov_model_stateless = OVModelForCausalLM.from_pretrained(
-            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
-        )
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
-
+
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
-        ov_model_stateful.generation_config.eos_token_id = None
-        ov_model_stateless.generation_config.eos_token_id = None
-        transformers_model.generation_config.eos_token_id = None
-        ov_model_stateful.config.eos_token_id = None
-        ov_model_stateless.config.eos_token_id = None
-        transformers_model.config.eos_token_id = None
-
-        # beam search
-        gen_config = GenerationConfig(
+        beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
             do_sample=False,
             eos_token_id=None,
         )
-
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
-        # beam sample
-        gen_config = GenerationConfig(
+        beam_sample_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
@@ -835,14 +810,7 @@ def test_beam_search(self, model_arch):
             top_k=1,
         )
 
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
-
-        # group beam search
-        gen_config = GenerationConfig(
+        group_beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
@@ -851,17 +819,9 @@ def test_beam_search(self, model_arch):
             num_beam_groups=2,
             diversity_penalty=0.0000001,
         )
-
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
-
-        # constrained beam search
         force_word = "cat"
         force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
-        gen_config = GenerationConfig(
+        constrained_beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
@@ -870,11 +830,31 @@ def test_beam_search(self, model_arch):
             force_words_ids=force_words_ids,
         )
 
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+        gen_configs = [
+            beam_search_gen_config, beam_sample_gen_config, group_beam_search_gen_config, constrained_beam_search_gen_config
+        ]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
+        )
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
+        )
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        for gen_config in gen_configs:
+            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+            ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+            ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
 
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
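
For readers skimming the diff: the refactor replaces four copy-pasted generate-and-compare blocks with a list of GenerationConfig objects and a single comparison loop, and moves model instantiation after the config definitions. Below is a minimal standalone sketch of that pattern, not the test itself; the checkpoint name is a placeholder assumption (the real test parametrizes model_id per architecture and also covers a stateless export), and only two of the four decoding strategies are shown.

# Minimal sketch of the consolidated comparison loop, assuming a small
# public checkpoint as a stand-in for the test's parametrized model_id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id  # GPT-2 has no pad token by default
tokens = tokenizer(["Today is a nice day", "This is me"], return_tensors="pt", padding=True)

# Define the decoding strategies up front, as the refactored test does:
# plain beam search and group beam search.
gen_configs = [
    GenerationConfig(max_new_tokens=10, min_new_tokens=10, num_beams=4, do_sample=False, eos_token_id=None),
    GenerationConfig(
        max_new_tokens=10,
        min_new_tokens=10,
        num_beams=4,
        num_beam_groups=2,
        diversity_penalty=0.0000001,
        eos_token_id=None,
    ),
]

ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
pt_model = AutoModelForCausalLM.from_pretrained(model_id)
for model in (ov_model, pt_model):
    # Disable EOS-based early stopping so both backends emit the same number of tokens.
    model.generation_config.eos_token_id = None
    model.config.eos_token_id = None

for gen_config in gen_configs:
    pt_outputs = pt_model.generate(**tokens, generation_config=gen_config)
    ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)
    assert torch.allclose(ov_outputs, pt_outputs)  # expect identical token ids

Defining the configs before instantiating the models mirrors the ordering the diff introduces, and the single loop means a new decoding strategy only needs one more entry in gen_configs rather than another copy of the generate-and-assert block.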