@@ -241,7 +241,6 @@ def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
-        # Test model forward do not need cache.
         ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -275,6 +274,38 @@ def test_compare_to_transformers(self, model_arch):
         self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
         self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_forward(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
+        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
+        self.assertIsInstance(ipex_model.config, PretrainedConfig)
+        input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long)
+        outputs = ipex_model(input_ids)
+
+        self.assertIsInstance(outputs.logits, torch.Tensor)
+
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
+        with torch.no_grad():
+            transformers_outputs = transformers_model(input_ids)
+
+        # Test re-load model
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE)
+            loaded_model_outputs = loaded_model(input_ids)
+
+        # Test init method
+        init_model = self.IPEX_MODEL_CLASS(transformers_model)
+        init_model_outputs = init_model(input_ids)
+
+        # Compare tensor outputs
+        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
+        # To avoid floating point error
+        self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
+        self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
         dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
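For anyone who wants to try the new check outside the test harness, below is a minimal standalone sketch of the same forward-equivalence pattern. It assumes optimum-intel's public `IPEXModelForCausalLM` export and uses `gpt2` as a hypothetical stand-in for the parameterized `MODEL_NAMES` entries; it is not part of this PR.

```python
# Standalone sketch of the equivalence check test_forward performs.
import torch
from transformers import AutoModelForCausalLM
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # hypothetical stand-in for MODEL_NAMES[model_arch]
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)

ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
reference = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

with torch.no_grad():
    ipex_logits = ipex_model(input_ids).logits
    reference_logits = reference(input_ids).logits

# The looser 1e-4 tolerance absorbs numerical drift from the IPEX kernels;
# the reload/init comparisons in the test use 1e-7 because only floating
# point error should separate those outputs.
assert torch.allclose(ipex_logits, reference_logits, atol=1e-4)
```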