@@ -1658,6 +1658,21 @@ def test_compare_to_transformers(self, model_arch):
1658
1658
transformers_outputs = transformers_model (** tokens , ** decoder_inputs )
1659
1659
# Compare tensor outputs
1660
1660
self .assertTrue (torch .allclose (ov_outputs .logits , transformers_outputs .logits , atol = 1e-4 ))
1661
+ gen_config = GenerationConfig (
1662
+ max_new_tokens = 10 ,
1663
+ min_new_tokens = 10 ,
1664
+ num_beams = 2 ,
1665
+ do_sample = False ,
1666
+ eos_token_id = None ,
1667
+ )
1668
+
1669
+ set_seed (SEED )
1670
+ generated_tokens = transformers_model .generate (** tokens , generation_config = gen_config )
1671
+ set_seed (SEED )
1672
+ ov_generated_tokens = ov_model .generate (** tokens , generation_config = gen_config )
1673
+
1674
+ self .assertTrue (torch .equal (generated_tokens , ov_generated_tokens ))
1675
+
1661
1676
del transformers_model
1662
1677
del ov_model
1663
1678
@@ -2355,12 +2370,12 @@ def test_compare_to_transformers(self, model_arch):
2355
2370
2356
2371
processor = get_preprocessor (model_id )
2357
2372
data = self ._generate_random_audio_data ()
2358
- features = processor .feature_extractor (data , return_tensors = "pt" )
2373
+ pt_features = processor .feature_extractor (data , return_tensors = "pt" )
2359
2374
decoder_start_token_id = transformers_model .config .decoder_start_token_id
2360
2375
decoder_inputs = {"decoder_input_ids" : torch .ones ((1 , 1 ), dtype = torch .long ) * decoder_start_token_id }
2361
2376
2362
2377
with torch .no_grad ():
2363
- transformers_outputs = transformers_model (** features , ** decoder_inputs )
2378
+ transformers_outputs = transformers_model (** pt_features , ** decoder_inputs )
2364
2379
2365
2380
for input_type in ["pt" , "np" ]:
2366
2381
features = processor .feature_extractor (data , return_tensors = input_type )
@@ -2373,6 +2388,21 @@ def test_compare_to_transformers(self, model_arch):
2373
2388
# Compare tensor outputs
2374
2389
self .assertTrue (torch .allclose (torch .Tensor (ov_outputs .logits ), transformers_outputs .logits , atol = 1e-3 ))
2375
2390
2391
+ gen_config = GenerationConfig (
2392
+ max_new_tokens = 10 ,
2393
+ min_new_tokens = 10 ,
2394
+ num_beams = 2 ,
2395
+ do_sample = False ,
2396
+ eos_token_id = None ,
2397
+ )
2398
+
2399
+ set_seed (SEED )
2400
+ generated_tokens = transformers_model .generate (** pt_features , generation_config = gen_config )
2401
+ set_seed (SEED )
2402
+ ov_generated_tokens = ov_model .generate (** pt_features , generation_config = gen_config )
2403
+
2404
+ self .assertTrue (torch .equal (generated_tokens , ov_generated_tokens ))
2405
+
2376
2406
del transformers_model
2377
2407
del ov_model
2378
2408
gc .collect ()
0 commit comments