@@ -498,8 +498,7 @@ def export_compressed_model(
498
498
gptq_config = self .gptq_config if hasattr (self , "gptq_config" ) else {}
499
499
500
500
autoround_config = self .autoround_config if hasattr (self , "autoround_config" ) else {}
501
-
502
- if gptq_config or (autoround_config and device == "xpu" ):
501
+ if gptq_config :
503
502
for k , v in weight_config .items ():
504
503
logger .debug (f"Compressing { k } on device { device } " )
505
504
if v ["dtype" ] == "fp32" :
@@ -558,19 +557,54 @@ def export_compressed_model(
558
557
)
559
558
new_module .pack (int_weight , gptq_scale , gptq_zp , m .bias , gptq_perm )
560
559
set_module (self .model , k , new_module )
561
- elif autoround_config and (device == "cpu" or device == "auto" ):
562
- from auto_round .export .export_to_itrex .export import pack_model # pylint: disable=E0401
560
+ elif autoround_config :
561
+ if device == "xpu" :
562
+ for k , v in weight_config .items ():
563
+ logger .debug (f"Compressing { k } on device { device } " )
564
+ if v ["dtype" ] == "fp32" :
565
+ continue
566
+ else :
567
+ dtype = v ["dtype" ]
568
+ num_bits = v ["bits" ]
569
+ group_size = v ["group_size" ]
570
+ scheme = v ["scheme" ]
571
+ m = fetch_module (self .model , k )
572
+ autoround_conf = autoround_config [k ]
573
+ fp32_weight = m .weight .data
574
+ autoround_scale = torch .tensor (autoround_conf ["scale" ], dtype = torch .float32 )
575
+ autoround_zp = None if scheme == "sym" else torch .tensor (autoround_conf ["zero" ], dtype = torch .int32 )
576
+ int_weight = quant_weight_w_scale (fp32_weight , autoround_scale , autoround_zp , group_size )
577
+ int_weight = int_weight .type (torch .int32 )
578
+ new_module = WeightOnlyLinear (
579
+ m .in_features ,
580
+ m .out_features ,
581
+ num_bits ,
582
+ group_size ,
583
+ dtype = dtype ,
584
+ zp = autoround_zp is not None ,
585
+ bias = m .bias is not None ,
586
+ g_idx = None ,
587
+ compression_dtype = compression_dtype ,
588
+ compression_dim = compression_dim ,
589
+ scale_dtype = scale_dtype ,
590
+ device = device ,
591
+ use_optimum_format = use_optimum_format ,
592
+ )
593
+ new_module .pack (int_weight , autoround_scale , autoround_zp , m .bias , None )
594
+ set_module (self .model , k , new_module )
595
+ else :
596
+ from auto_round .export .export_to_itrex .export import pack_model # pylint: disable=E0401
563
597
564
- self .model = pack_model (
565
- self .model ,
566
- weight_config = autoround_config ,
567
- enable_full_range = enable_full_range ,
568
- compression_dtype = compression_dtype ,
569
- compression_dim = compression_dim ,
570
- device = device ,
571
- use_optimum_format = use_optimum_format ,
572
- inplace = True ,
573
- )
598
+ self .model = pack_model (
599
+ self .model ,
600
+ weight_config = autoround_config ,
601
+ enable_full_range = enable_full_range ,
602
+ compression_dtype = compression_dtype ,
603
+ compression_dim = compression_dim ,
604
+ device = device ,
605
+ use_optimum_format = use_optimum_format ,
606
+ inplace = True ,
607
+ )
574
608
else :
575
609
for k , v in weight_config .items ():
576
610
logger .debug (f"Compressing { k } on device { device } " )
0 commit comments