@@ -1559,14 +1559,18 @@ def test_compare_output_attentions(self, model_arch):
1559
1559
self .assertIn ("logits" , ov_outputs )
1560
1560
self .assertIsInstance (ov_outputs .logits , TENSOR_ALIAS_TO_TYPE [input_type ])
1561
1561
self .assertTrue (torch .allclose (torch .Tensor (ov_outputs .logits ), transformers_outputs .logits , atol = 1e-4 ))
1562
- self .assertTrue (
1563
- all (
1562
+ self .assertTrue (len (ov_outputs .attentions ) == len (transformers_outputs .attentions ))
1563
+ for i in range (len (ov_outputs .attentions )):
1564
+ self .assertTrue (
1564
1565
torch .allclose (
1565
- torch .Tensor (ov_outputs .attentions [i ]), transformers_outputs .attentions [i ], atol = 1e-4
1566
- )
1567
- for i in range (len (ov_outputs .attentions ))
1566
+ torch .Tensor (ov_outputs .attentions [i ]),
1567
+ transformers_outputs .attentions [i ],
1568
+ atol = 1e-4 , # attentions are accurate
1569
+ rtol = 1e-4 , # attentions are accurate
1570
+ ),
1571
+ f"Attention mismatch at layer { i } " ,
1568
1572
)
1569
- )
1573
+
1570
1574
del transformers_model
1571
1575
del ov_model
1572
1576
gc .collect ()
@@ -1592,14 +1596,17 @@ def test_compare_output_hidden_states(self, model_arch):
1592
1596
self .assertIn ("logits" , ov_outputs )
1593
1597
self .assertIsInstance (ov_outputs .logits , TENSOR_ALIAS_TO_TYPE [input_type ])
1594
1598
self .assertTrue (torch .allclose (torch .Tensor (ov_outputs .logits ), transformers_outputs .logits , atol = 1e-4 ))
1595
- self .assertTrue (
1596
- all (
1599
+ self .assertTrue (len (ov_outputs .hidden_states ) == len (transformers_outputs .hidden_states ))
1600
+ for i in range (len (ov_outputs .hidden_states )):
1601
+ self .assertTrue (
1597
1602
torch .allclose (
1598
- torch .Tensor (ov_outputs .hidden_states [i ]), transformers_outputs .hidden_states [i ], atol = 1e-4
1599
- )
1600
- for i in range (len (ov_outputs .hidden_states ))
1603
+ torch .Tensor (ov_outputs .hidden_states [i ]),
1604
+ transformers_outputs .hidden_states [i ],
1605
+ atol = 1e-3 , # hidden states are less accurate
1606
+ rtol = 1e-2 , # hidden states are less accurate
1607
+ ),
1608
+ f"Hidden states mismatch at layer { i } " ,
1601
1609
)
1602
- )
1603
1610
del transformers_model
1604
1611
del ov_model
1605
1612
gc .collect ()
0 commit comments