Skip to content

Commit 3a966c5

Browse files
committed
testiong low precision ipex model
1 parent 08717f2 commit 3a966c5

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/ipex/test_modeling.py

+20
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,26 @@ def test_pipeline(self, model_arch):
144144
_ = pipe(text)
145145
self.assertEqual(pipe.device, model.device)
146146

147+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
148+
def test_low_precision(self, model_arch):
149+
model_id = MODEL_NAMES[model_arch]
150+
ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)
151+
self.assertEqual(ipex_model._dtype, torch.bfloat16)
152+
transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(
153+
model_id, torch_dtype=torch.bfloat16
154+
)
155+
tokenizer = AutoTokenizer.from_pretrained(model_id)
156+
inputs = "This is a sample input"
157+
tokens = tokenizer(inputs, return_tensors="pt")
158+
with torch.no_grad():
159+
transformers_outputs = transformers_model(**tokens)
160+
outputs = ipex_model(**tokens)
161+
# Compare tensor outputs
162+
for output_name in {"logits", "last_hidden_state"}:
163+
if output_name in transformers_outputs:
164+
self.assertEqual(outputs[output_name].dtype, torch.bfloat16)
165+
self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-1))
166+
147167

148168
class IPEXModelForSequenceClassificationTest(IPEXModelTest):
149169
IPEX_MODEL_CLASS = IPEXModelForTokenClassification

0 commit comments

Comments
 (0)