@@ -2318,21 +2318,28 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         "bloom",
         "codegen",
         "falcon",
-        "gemma",
         "gpt2",
         "gpt_bigcode",
         "gpt_neo",
         "gpt_neox",
         "gptj",
-        "granite",
         "llama",
         "mistral",
-        "mpt",
         "opt",
     ]
 
-    if check_if_transformers_greater("4.40"):
-        SUPPORTED_ARCHITECTURES.extend(["gemma", "phi3", "qwen2"])
+    if check_if_transformers_greater("4.37"):
+        SUPPORTED_ARCHITECTURES.append("qwen2")
+
+    if check_if_transformers_greater("4.38"):
+        SUPPORTED_ARCHITECTURES.append("gemma")
+
+    # TODO: fix "mpt" for which inference fails for transformers < v4.41
+    if check_if_transformers_greater("4.41"):
+        SUPPORTED_ARCHITECTURES.extend(["phi3", "mpt"])
+
+    if check_if_transformers_greater("4.45"):
+        SUPPORTED_ARCHITECTURES.append("granite")
 
 
     FULL_GRID = {
         "model_arch": SUPPORTED_ARCHITECTURES,
@@ -2445,7 +2452,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         transformers_model = transformers_model.eval()
         tokenizer = get_preprocessor(model_id)
-        tokens = tokenizer("This is a sample output", return_tensors="pt")
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
         position_ids = None
         if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
             input_shape = tokens["input_ids"].shape
@@ -2467,7 +2474,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         # Compare batched generation.
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
-        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        tokens = tokenizer(["This is", "This is a sample input"], return_tensors="pt", padding=True)
         onnx_model.generation_config.eos_token_id = None
         transformers_model.generation_config.eos_token_id = None
         onnx_model.config.eos_token_id = None
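
The batched comparison above depends on the unchanged `padding_side = "left"` context line: decoder-only models must be padded on the left so every prompt ends at the same position. A standalone sketch of that pattern (the `gpt2` checkpoint and generation settings are illustrative assumptions, not what the test pins down):

```python
# Minimal sketch of the batched-generation pattern above; "gpt2" is a
# stand-in checkpoint, not the model the test actually exercises.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

# Left padding keeps pad tokens out of the way: with right padding they
# would sit between the prompt and the generated continuation.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokens = tokenizer(["This is", "This is a sample input"], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model.generate(**tokens, max_new_tokens=10, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
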
@@ -4598,14 +4605,14 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str):
         )
 
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
-        self.assertEqual(
-            outputs_model_with_pkv.shape[1],
-            self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1,
-        )
-        self.assertEqual(
-            outputs_model_without_pkv.shape[1],
-            self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1,
-        )
+
+        if model_arch == "whisper" and check_if_transformers_greater("4.43"):
+            gen_length = self.GENERATION_LENGTH + 2
+        else:
+            gen_length = self.GENERATION_LENGTH + 1
+
+        self.assertEqual(outputs_model_with_pkv.shape[1], gen_length)
+        self.assertEqual(outputs_model_without_pkv.shape[1], gen_length)
 
 
         self.GENERATION_LENGTH = generation_length
         if os.environ.get("TEST_LEVEL", 0) == "1":
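
The refactor above computes the expected sequence length once instead of repeating the conditional inside each assertion, and additionally gates the whisper offset on the transformers version. The rule it encodes, as a standalone sketch (the helper name is an assumption; the +2 offset for whisper on transformers >= 4.43 comes straight from the diff, presumably covering an extra decoder special token):

```python
# Sketch of the expected-length rule the refactored assertions encode; the
# offsets come from the diff itself, the helper name is hypothetical.
def expected_generation_length(generation_length: int, model_arch: str, transformers_is_ge_4_43: bool) -> int:
    # Whisper on transformers >= 4.43 produces one extra token in the
    # returned sequence, hence the larger offset.
    if model_arch == "whisper" and transformers_is_ge_4_43:
        return generation_length + 2
    return generation_length + 1


assert expected_generation_length(10, "whisper", True) == 12
assert expected_generation_length(10, "llama", True) == 11
```
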