Skip to content

Commit b284bc9

Browse files
better testing
1 parent 1ab12a4 commit b284bc9

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

tests/openvino/test_modeling.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -1559,14 +1559,18 @@ def test_compare_output_attentions(self, model_arch):
15591559
self.assertIn("logits", ov_outputs)
15601560
self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
15611561
self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
1562-
self.assertTrue(
1563-
all(
1562+
self.assertTrue(len(ov_outputs.attentions) == len(transformers_outputs.attentions))
1563+
for i in range(len(ov_outputs.attentions)):
1564+
self.assertTrue(
15641565
torch.allclose(
1565-
torch.Tensor(ov_outputs.attentions[i]), transformers_outputs.attentions[i], atol=1e-4
1566-
)
1567-
for i in range(len(ov_outputs.attentions))
1566+
torch.Tensor(ov_outputs.attentions[i]),
1567+
transformers_outputs.attentions[i],
1568+
atol=1e-4, # attentions are accurate
1569+
rtol=1e-4, # attentions are accurate
1570+
),
1571+
f"Attention mismatch at layer {i}",
15681572
)
1569-
)
1573+
15701574
del transformers_model
15711575
del ov_model
15721576
gc.collect()
@@ -1592,14 +1596,17 @@ def test_compare_output_hidden_states(self, model_arch):
15921596
self.assertIn("logits", ov_outputs)
15931597
self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type])
15941598
self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-4))
1595-
self.assertTrue(
1596-
all(
1599+
self.assertTrue(len(ov_outputs.hidden_states) == len(transformers_outputs.hidden_states))
1600+
for i in range(len(ov_outputs.hidden_states)):
1601+
self.assertTrue(
15971602
torch.allclose(
1598-
torch.Tensor(ov_outputs.hidden_states[i]), transformers_outputs.hidden_states[i], atol=1e-4
1599-
)
1600-
for i in range(len(ov_outputs.hidden_states))
1603+
torch.Tensor(ov_outputs.hidden_states[i]),
1604+
transformers_outputs.hidden_states[i],
1605+
atol=1e-3, # hidden states are less accurate
1606+
rtol=1e-2, # hidden states are less accurate
1607+
),
1608+
f"Hidden states mismatch at layer {i}",
16011609
)
1602-
)
16031610
del transformers_model
16041611
del ov_model
16051612
gc.collect()

0 commit comments

Comments
 (0)