
Commit 40c26f7

refactor test
1 parent daecdac commit 40c26f7

File tree: 1 file changed (+30, -50 lines)

tests/openvino/test_modeling.py (+30, -50)
@@ -792,41 +792,16 @@ def test_beam_search(self, model_arch):
         # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search
         if model_arch in ["qwen", "chatglm"]:
             return
-
-        ov_model_stateful = OVModelForCausalLM.from_pretrained(
-            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
-        )
-        ov_model_stateless = OVModelForCausalLM.from_pretrained(
-            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
-        )
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
-
+
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
-        ov_model_stateful.generation_config.eos_token_id = None
-        ov_model_stateless.generation_config.eos_token_id = None
-        transformers_model.generation_config.eos_token_id = None
-        ov_model_stateful.config.eos_token_id = None
-        ov_model_stateless.config.eos_token_id = None
-        transformers_model.config.eos_token_id = None
-
-        # beam search
-        gen_config = GenerationConfig(
+        beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
             do_sample=False,
             eos_token_id=None,
         )
-
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
-        # beam sample
-        gen_config = GenerationConfig(
+        beam_sample_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
@@ -835,14 +810,7 @@ def test_beam_search(self, model_arch):
             top_k=1,
         )
 
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
-
-        # group beam search
-        gen_config = GenerationConfig(
+        group_beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
@@ -851,17 +819,9 @@ def test_beam_search(self, model_arch):
             num_beam_groups=2,
             diversity_penalty=0.0000001,
         )
-
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
-
-        # constrained beam search
         force_word = "cat"
         force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
-        gen_config = GenerationConfig(
+        constrained_beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
@@ -870,11 +830,31 @@ def test_beam_search(self, model_arch):
             force_words_ids=force_words_ids,
         )
 
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
-        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
-        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+        gen_configs = [
+            beam_search_gen_config, beam_sample_gen_config, group_beam_search_gen_config, constrained_beam_search_gen_config
+        ]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=True, **model_kwargs
+        )
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=True, stateful=False, **model_kwargs
+        )
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        for gen_config in gen_configs:
+            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+            ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+            ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+            self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
 
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
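
Note: the diff above spans four hunks, so the net shape of the refactored test can be hard to see. Below is a condensed, standalone sketch of the resulting flow, assuming the usual transformers/optimum-intel APIs; the model id, model_kwargs, and the single example GenerationConfig are placeholders for illustration, since the real test is parameterized per architecture and builds all four beam-search configs.

# Condensed sketch of the refactored flow (not the verbatim test).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from optimum.intel import OVModelForCausalLM  # assuming the standard optimum-intel import path

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder tiny model
model_kwargs = {}

# Build all generation configs up front (beam search shown; the test also
# covers beam sample, group beam search and constrained beam search).
gen_configs = [
    GenerationConfig(max_new_tokens=10, min_new_tokens=10, num_beams=4, do_sample=False, eos_token_id=None),
]

# Load each model once instead of once per decoding strategy.
ov_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True, **model_kwargs)
ov_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False, **model_kwargs)
reference = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)

# Disable EOS so every strategy generates the full 10 tokens deterministically.
for model in (ov_stateful, ov_stateless, reference):
    model.generation_config.eos_token_id = None
    model.config.eos_token_id = None

# One loop replaces the four copy-pasted generate/compare blocks.
for gen_config in gen_configs:
    expected = reference.generate(**tokens, generation_config=gen_config)
    assert torch.allclose(ov_stateful.generate(**tokens, generation_config=gen_config), expected)
    assert torch.allclose(ov_stateless.generate(**tokens, generation_config=gen_config), expected)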
