from transformers.testing_utils import slow
from utils_tests import MODEL_NAMES

+from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
from optimum.intel import (
    OVModelForAudioClassification,
    OVModelForAudioFrameClassification,
@@ -647,6 +648,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
    if is_transformers_version(">=", "4.40.0"):
        SUPPORTED_ARCHITECTURES += (
            "gemma",
+            "gemma2",
            "olmo",
            "stablelm",
            "starcoder2",
@@ -728,7 +730,8 @@ def test_compare_to_transformers(self, model_arch):
        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4))

        # Qwen tokenizer does not support padding
-        if model_arch == "qwen":
+
+        if model_arch in ["qwen"]:
            return

        if model_arch not in ["chatglm", "glm4", "persimmon"]:
@@ -753,7 +756,16 @@ def test_compare_to_transformers(self, model_arch):
        )

        ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        additional_inputs = {}
+        # gemma2 does not support a dynamic cache, and comparing a dynamic-cache result against a
+        # hybrid-cache one would be unfair, so align the cache representation in the torch model
+        if model_arch == "gemma2":
+            patch_update_causal_mask(transformers_model, "4.43.0")
+            transformers_model._supports_cache_class = True
+            from transformers.cache_utils import DynamicCache
+
+            additional_inputs = {"past_key_values": DynamicCache()}
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
        self.assertTrue(torch.allclose(ov_outputs, transformers_outputs))

        del transformers_model
@@ -921,8 +933,8 @@ def test_beam_search(self, model_arch):
                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
                "trust_remote_code": True,
            }
-        # Qwen tokenizer does not support padding, chatgm testing model produces nan that incompatible with beam search
-        if model_arch in ["qwen", "chatglm"]:
+        # Qwen tokenizer does not support padding; the chatglm and glm4 test models produce NaNs that are incompatible with beam search
+        if model_arch in ["qwen", "chatglm", "glm4"]:
            return

        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
@@ -988,6 +1000,12 @@ def test_beam_search(self, model_arch):

        if model_arch == "arctic":
            transformers_model.to(torch.float32)
+        additional_inputs = {}
+        # gemma2 does not support a dynamic cache, and comparing a dynamic-cache result against a hybrid-cache one would be unfair, so align the cache representation in the torch model
+        if model_arch == "gemma2":
+            patch_update_causal_mask(transformers_model, "4.43.0")
+            transformers_model._supports_cache_class = True
+            from transformers.cache_utils import DynamicCache
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
        tokens.pop("token_type_ids", None)
@@ -1002,7 +1020,12 @@ def test_beam_search(self, model_arch):
            if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
                continue
            set_seed(SEED)
-            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+
+            if model_arch == "gemma2":
+                additional_inputs = {"past_key_values": DynamicCache()}
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
            set_seed(SEED)
            ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
            self.assertTrue(
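For context, a minimal standalone sketch of the cache-alignment trick the new gemma2 branches rely on. The checkpoint id, prompt, and token count below are illustrative assumptions, not taken from this PR; the substance is that patch_update_causal_mask plus an explicit DynamicCache passed to generate() keeps the PyTorch model on a plain growing key/value cache, so its outputs can be compared fairly with the stateful OpenVINO model instead of against gemma2's default hybrid (sliding-window) cache.

# Sketch only: align a gemma2 transformers model on a dynamic KV cache before
# comparing it against an OpenVINO model. Checkpoint id and prompt are
# hypothetical placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

from optimum.exporters.openvino.model_patcher import patch_update_causal_mask

model_id = "google/gemma-2-2b"  # illustrative gemma2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

# Same two steps as in the diff above: patch the causal-mask update for the
# target transformers version and mark Cache objects as supported.
patch_update_causal_mask(model, "4.43.0")
model._supports_cache_class = True

tokens = tokenizer("This is a sample input", return_tensors="pt")

# Passing an explicit DynamicCache keeps generation on a plain growing
# key/value cache instead of gemma2's default hybrid sliding-window cache.
outputs = model.generate(**tokens, past_key_values=DynamicCache(), max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))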