@@ -230,11 +230,13 @@ def __init__(
 
         # device
         self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
-        self.model.to(self.device)
+        if not use_layer_wise:
+            self.model.to(self.device)
         self.is_ready = False
 
         self.use_layer_wise = use_layer_wise
-        self.model_path = model_path
+        if use_layer_wise:
+            self.prepare_layer_wise(model_path)
 
         # dataloader
         self.use_max_length = use_max_length
@@ -243,6 +245,20 @@ def __init__(
         self.dataloader = []
         self.nsamples = nsamples
 
+    def prepare_layer_wise(self, model_path):
+        import os
+
+        from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, get_path, register_weight_hooks
+
+        os.makedirs(LWQ_WORKSPACE, exist_ok=True)
+        if model_path == "":
+            model_path = self.model.path
+        assert model_path, "model_path should not be None."
+        self.model_path = get_path(model_path)
+        register_weight_hooks(
+            self.model, self.model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
+        )
+
     def get_full_layer_name(self, sub_layer_name, block_idx):
         transformer_name = self.gptq_related_blocks["transformers_name"]
         return ".".join([transformer_name, str(block_idx), sub_layer_name])
@@ -413,7 +429,6 @@ def execute_quantization(self, means=None, stds=None):
         # Step1: prepare quantization (calibration datasets)
 
         logger.info("Begin ====>")
-        model_path = self.model_path
 
         # Step2: run gptq quantization in a transformer block-wise manner.
         gptq_config = {}
@@ -450,7 +465,7 @@ def execute_quantization(self, means=None, stds=None):
                 if self.use_layer_wise:  # pragma: no cover
                     from neural_compressor.torch.algorithms.layer_wise import load_value
 
-                    W = load_value(self.model, full_layer_name + ".weight", model_path)
+                    W = load_value(self.model, full_layer_name + ".weight", self.model_path)
                 else:
                     W = sub_layers[layer_name].weight.data.clone()
@@ -489,7 +504,7 @@ def tmp(_, inp, out):
                     from neural_compressor.torch.algorithms.layer_wise import load_value
 
                     full_layer_name = self.get_full_layer_name(layer_name, block_idx)
-                    W = load_value(self.model, full_layer_name + ".weight", model_path)
+                    W = load_value(self.model, full_layer_name + ".weight", self.model_path)
                 else:
                     W = sub_layers[layer_name].weight.data.clone()
                 accelerator.mark_step()
@@ -518,7 +533,7 @@ def tmp(_, inp, out):
                         if n == "weight":
                             set_module_tensor_to_device(self.model, param_name, self.device, Q)
                         else:
-                            value = load_value(self.model, param_name, model_path)
+                            value = load_value(self.model, param_name, self.model_path)
                             set_module_tensor_to_device(self.model, param_name, self.device, value)
                     # sub_layer.weight.data = Q
                     torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
@@ -562,7 +577,13 @@ def tmp(_, inp, out):
                     gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"]
                 else:
                     gptq_perm = None
-                Q = sub_layers[layer_name].weight.data
+                if self.use_layer_wise:
+                    state_dict = torch.load(LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt")
+                    Q = state_dict["weight"].data
+                    bias = state_dict["bias"] if "bias" in state_dict.keys() else None
+
+                else:
+                    Q = sub_layers[layer_name].weight.data
                 if weight_config_this_layer["act_order"]:
                     Q.copy_(Q[:, gptq_perm])
                 if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
@@ -591,18 +612,21 @@ def tmp(_, inp, out):
                     scale = scale.t_().contiguous()
                     zp = zp.t_().contiguous() if zp is not None else zp
 
+                if not self.use_layer_wise:
+                    bias = sub_layers[layer_name].bias
+
                 new_module = WeightOnlyLinear(
                     in_features,
                     out_features,
                     dtype=weight_config_this_layer["dtype"],
                     bits=weight_config_this_layer["bits"],
                     group_size=weight_config_this_layer["group_size"],
                     zp=gptq_zp is not None,
-                    bias=sub_layers[layer_name].bias is not None,
+                    bias=bias is not None,
                     g_idx=gptq_perm is not None,
                     device=self.device,
                 )
-                new_module.pack(int_weight, gptq_scale, gptq_zp, sub_layers[layer_name].bias, gptq_perm)
+                new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm)
                 set_module(transformer_block, layer_name, new_module)
             del gptq_for_this_block
             torch.cuda.empty_cache()
@@ -1019,8 +1043,10 @@ def prepare(
     def convert(self, model, *args, **kwargs):
         self.gptq_quantizer.model = model
         self.gptq_quantizer.remove_prepare_for_calibration()
+
        q_model, gptq_config = self.gptq_quantizer.execute_quantization()
-        q_model = q_model.to(self.model_device)
+        if not self.gptq_quantizer.use_layer_wise:
+            q_model = q_model.to(self.model_device)
         q_model.gptq_config = gptq_config
         logger.info("GPTQ quantizing done.")
         return q_model
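
For context, a minimal usage sketch of the layer-wise path this change enables. It assumes the neural_compressor 3.x PyTorch entry points (GPTQConfig, prepare, convert) forward use_layer_wise and model_path down to the quantizer __init__ patched above; the exact keyword names and the placeholder checkpoint are assumptions, not taken from this diff.

# Hypothetical sketch: drive the layer-wise GPTQ path added in this diff.
# GPTQConfig/prepare/convert come from the INC 3.x PyTorch API; use_layer_wise and
# model_path are assumed to be forwarded to the quantizer __init__ shown above.
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import GPTQConfig, prepare, convert

model_name = "facebook/opt-125m"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# With use_layer_wise=True the quantizer calls prepare_layer_wise(): weights are
# loaded per layer from model_path (falling back to model.path when it is ""),
# so the full model is never moved to the target device.
quant_config = GPTQConfig(bits=4, group_size=128, use_layer_wise=True, model_path=model_name)

model = prepare(model, quant_config)
calib_ids = tokenizer("calibration text", return_tensors="pt")["input_ids"]
model(calib_ids)  # one calibration pass so each block sees real inputs
q_model = convert(model)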