@@ -483,6 +483,45 @@ def test_bnb(self):
483
483
self .assertTrue (torch .allclose (outputs .logits , loaded_model_outputs .logits , atol = 1e-7 ))
484
484
self .assertTrue (torch .allclose (outputs .logits , init_model_outputs .logits , atol = 1e-7 ))
485
485
486
@unittest.skipIf(not is_auto_awq_available(), reason="Test requires autoawq")
def test_awq(self):
    # Loads an AWQ checkpoint through the IPEX causal-LM wrapper and checks
    # that its logits agree (within a loose tolerance) with the plain
    # transformers model, and that both a save/reload round trip and direct
    # construction from the transformers model give near-identical logits.
    checkpoint = "PrunaAI/JackFram-llama-68m-AWQ-4bit-smashed"
    set_seed(SEED)

    # float16 when XPU is available, float32 otherwise.
    if IS_XPU_AVAILABLE:
        dtype = torch.float16
    else:
        dtype = torch.float32
    # Every from_pretrained call below shares the same dtype/device settings.
    load_kwargs = {"torch_dtype": dtype, "device_map": DEVICE}

    # The forward pass exercised here does not need a cache.
    model = IPEXModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    self.assertIsInstance(model.config, PretrainedConfig)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoded = tokenizer(
        "This is a sample",
        return_tensors="pt",
        return_token_type_ids=False,
    )
    encoded = encoded.to(DEVICE)
    model_inputs = model.prepare_inputs_for_generation(**encoded)
    ipex_outputs = model(**model_inputs)

    self.assertIsInstance(ipex_outputs.logits, torch.Tensor)

    # Reference run with the plain transformers model on the raw tokens.
    reference_model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    with torch.no_grad():
        reference_outputs = reference_model(**encoded)

    # Save the IPEX model to a temporary directory, load it back and re-run.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        reloaded_model = self.IPEX_MODEL_CLASS.from_pretrained(save_dir, **load_kwargs)
        reloaded_outputs = reloaded_model(**model_inputs)

    # Build the IPEX model directly from the already-loaded transformers model.
    wrapped_model = self.IPEX_MODEL_CLASS(reference_model)
    wrapped_outputs = wrapped_model(**model_inputs)

    # Compare logits: loose tolerance against the transformers reference,
    # a tiny tolerance (only to absorb floating-point error) between the
    # IPEX variants.
    ipex_logits = ipex_outputs.logits
    self.assertTrue(torch.allclose(ipex_logits, reference_outputs.logits, atol=5e-2))
    self.assertTrue(torch.allclose(ipex_logits, reloaded_outputs.logits, atol=1e-7))
    self.assertTrue(torch.allclose(ipex_logits, wrapped_outputs.logits, atol=1e-7))
524
+
486
525
487
526
class IPEXModelForAudioClassificationTest (unittest .TestCase ):
488
527
IPEX_MODEL_CLASS = IPEXModelForAudioClassification
0 commit comments