@@ -793,40 +793,15 @@ def test_beam_search(self, model_arch):
793
793
if model_arch in ["qwen" , "chatglm" ]:
794
794
return
795
795
796
- ov_model_stateful = OVModelForCausalLM .from_pretrained (
797
- model_id , export = True , use_cache = True , stateful = True , ** model_kwargs
798
- )
799
- ov_model_stateless = OVModelForCausalLM .from_pretrained (
800
- model_id , export = True , use_cache = True , stateful = False , ** model_kwargs
801
- )
802
- transformers_model = AutoModelForCausalLM .from_pretrained (model_id , ** model_kwargs )
803
-
804
796
tokenizer = AutoTokenizer .from_pretrained (model_id , trust_remote_code = model_arch in self .REMOTE_CODE_MODELS )
805
- tokenizer .pad_token_id = tokenizer .eos_token_id
806
- tokens = tokenizer (["Today is a nice day and I am longer" , "This is me" ], return_tensors = "pt" , padding = True )
807
- ov_model_stateful .generation_config .eos_token_id = None
808
- ov_model_stateless .generation_config .eos_token_id = None
809
- transformers_model .generation_config .eos_token_id = None
810
- ov_model_stateful .config .eos_token_id = None
811
- ov_model_stateless .config .eos_token_id = None
812
- transformers_model .config .eos_token_id = None
813
-
814
- # beam search
815
- gen_config = GenerationConfig (
797
+ beam_search_gen_config = GenerationConfig (
816
798
max_new_tokens = 10 ,
817
799
min_new_tokens = 10 ,
818
800
num_beams = 4 ,
819
801
do_sample = False ,
820
802
eos_token_id = None ,
821
803
)
822
-
823
- transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
824
- ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
825
- self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
826
- ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
827
- self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
828
- # beam sample
829
- gen_config = GenerationConfig (
804
+ beam_sample_gen_config = GenerationConfig (
830
805
max_new_tokens = 10 ,
831
806
min_new_tokens = 10 ,
832
807
num_beams = 4 ,
@@ -835,14 +810,7 @@ def test_beam_search(self, model_arch):
835
810
top_k = 1 ,
836
811
)
837
812
838
- transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
839
- ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
840
- self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
841
- ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
842
- self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
843
-
844
- # group beam search
845
- gen_config = GenerationConfig (
813
+ group_beam_search_gen_config = GenerationConfig (
846
814
max_new_tokens = 10 ,
847
815
min_new_tokens = 10 ,
848
816
num_beams = 4 ,
@@ -851,17 +819,9 @@ def test_beam_search(self, model_arch):
851
819
num_beam_groups = 2 ,
852
820
diversity_penalty = 0.0000001 ,
853
821
)
854
-
855
- transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
856
- ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
857
- self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
858
- ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
859
- self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
860
-
861
- # constrained beam search
862
822
force_word = "cat"
863
823
force_words_ids = [tokenizer ([force_word ], add_special_tokens = False ).input_ids ]
864
- gen_config = GenerationConfig (
824
+ constrained_beam_search_gen_config = GenerationConfig (
865
825
max_new_tokens = 10 ,
866
826
min_new_tokens = 10 ,
867
827
num_beams = 4 ,
@@ -870,11 +830,34 @@ def test_beam_search(self, model_arch):
870
830
force_words_ids = force_words_ids ,
871
831
)
872
832
873
- transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
874
- ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
875
- self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
876
- ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
877
- self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
833
+ gen_configs = [
834
+ beam_search_gen_config ,
835
+ beam_sample_gen_config ,
836
+ group_beam_search_gen_config ,
837
+ constrained_beam_search_gen_config ,
838
+ ]
839
+ ov_model_stateful = OVModelForCausalLM .from_pretrained (
840
+ model_id , export = True , use_cache = True , stateful = True , ** model_kwargs
841
+ )
842
+ ov_model_stateless = OVModelForCausalLM .from_pretrained (
843
+ model_id , export = True , use_cache = True , stateful = False , ** model_kwargs
844
+ )
845
+ transformers_model = AutoModelForCausalLM .from_pretrained (model_id , ** model_kwargs )
846
+ tokenizer .pad_token_id = tokenizer .eos_token_id
847
+ tokens = tokenizer (["Today is a nice day and I am longer" , "This is me" ], return_tensors = "pt" , padding = True )
848
+ ov_model_stateful .generation_config .eos_token_id = None
849
+ ov_model_stateless .generation_config .eos_token_id = None
850
+ transformers_model .generation_config .eos_token_id = None
851
+ ov_model_stateful .config .eos_token_id = None
852
+ ov_model_stateless .config .eos_token_id = None
853
+ transformers_model .config .eos_token_id = None
854
+
855
+ for gen_config in gen_configs :
856
+ transformers_outputs = transformers_model .generate (** tokens , generation_config = gen_config )
857
+ ov_stateful_outputs = ov_model_stateful .generate (** tokens , generation_config = gen_config )
858
+ self .assertTrue (torch .allclose (ov_stateful_outputs , transformers_outputs ))
859
+ ov_stateless_outputs = ov_model_stateless .generate (** tokens , generation_config = gen_config )
860
+ self .assertTrue (torch .allclose (ov_stateless_outputs , transformers_outputs ))
878
861
879
862
880
863
class OVModelForMaskedLMIntegrationTest (unittest .TestCase ):
0 commit comments