@@ -144,6 +144,26 @@ def test_pipeline(self, model_arch):
144
144
_ = pipe (text )
145
145
self .assertEqual (pipe .device , model .device )
146
146
147
+ @parameterized .expand (SUPPORTED_ARCHITECTURES )
148
+ def test_low_precision (self , model_arch ):
149
+ model_id = MODEL_NAMES [model_arch ]
150
+ ipex_model = self .IPEX_MODEL_CLASS .from_pretrained (model_id , export = True , torch_dtype = torch .bfloat16 )
151
+ self .assertEqual (ipex_model ._dtype , torch .bfloat16 )
152
+ transformers_model = self .IPEX_MODEL_CLASS .auto_model_class .from_pretrained (
153
+ model_id , torch_dtype = torch .bfloat16
154
+ )
155
+ tokenizer = AutoTokenizer .from_pretrained (model_id )
156
+ inputs = "This is a sample input"
157
+ tokens = tokenizer (inputs , return_tensors = "pt" )
158
+ with torch .no_grad ():
159
+ transformers_outputs = transformers_model (** tokens )
160
+ outputs = ipex_model (** tokens )
161
+ # Compare tensor outputs
162
+ for output_name in {"logits" , "last_hidden_state" }:
163
+ if output_name in transformers_outputs :
164
+ self .assertEqual (outputs [output_name ].dtype , torch .bfloat16 )
165
+ self .assertTrue (torch .allclose (outputs [output_name ], transformers_outputs [output_name ], atol = 1e-1 ))
166
+
147
167
148
168
class IPEXModelForSequenceClassificationTest (IPEXModelTest ):
149
169
IPEX_MODEL_CLASS = IPEXModelForTokenClassification
0 commit comments