@@ -718,10 +718,10 @@ def tmp(_, inp, out):
                     for n, p in sub_layer.named_parameters():
                         param_name = full_layer_name + "." + n
                         if n == "weight":
-                            set_module_tensor_to_device(self.model, param_name, self.device, Q)
+                            set_module_tensor_to_device(self.model, param_name, self.device, Q, dtype=Q.dtype)
                         else:
                             value = load_value(self.model, param_name, model_path)
-                            set_module_tensor_to_device(self.model, param_name, self.device, value)
+                            set_module_tensor_to_device(self.model, param_name, self.device, value, dtype=value.dtype)
                     # sub_layer.weight.data = Q
                     torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
                     clean_module_weight(sub_layer)
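
The new dtype= argument matters here: as far as I understand accelerate's set_module_tensor_to_device, when no dtype is given the incoming value is cast to the dtype of the parameter it replaces, whereas passing dtype=Q.dtype / dtype=value.dtype keeps the tensor in the dtype it already carries. A minimal standalone sketch (the Linear layer, shapes, and fp16 weight below are invented for illustration, not taken from this code):

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    layer = torch.nn.Linear(4, 4)  # parameters start on the "meta" device in fp32

Q = torch.randn(4, 4, dtype=torch.float16)  # stand-in for a processed/quantized weight

# Without dtype=..., the value would be converted to the old parameter's dtype (fp32);
# with dtype=Q.dtype the materialized weight stays in fp16.
set_module_tensor_to_device(layer, "weight", "cpu", Q, dtype=Q.dtype)
print(layer.weight.dtype)  # torch.float16
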
@@ -745,7 +745,8 @@ def tmp(_, inp, out):
         for j in range(len(self.dataloader)):
             cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
             cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
-            transformer_block = transformer_block.to(cache_positional_batch[0].dtype)
+            # breakpoint()
+            # transformer_block = transformer_block.to(getattr(torch, self.model.config.torch_dtype))
             out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
             out = self.track_hidden_states(out)
             outs.append(out)
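
For context, the two gather helpers above rebuild the j-th calibration sample from the caches filled in by the forward hooks. A rough sketch of the assumed shape of these helpers (the real implementations live elsewhere in this file; the argument layouts are illustrative only):

# Assumed cache layout: each keyword argument maps to a list with one entry per
# calibration batch, and positional arguments are stored as parallel lists.
def gather_single_batch_from_dict(data_dict, idx):
    # {"attention_mask": [b0, b1, ...]} -> {"attention_mask": b_idx}
    return {key: values[idx] for key, values in data_dict.items()}

def gather_single_batch_from_list(data_list, idx):
    # [[hidden0, hidden1, ...], [pos0, pos1, ...]] -> [hidden_idx, pos_idx]
    return [values[idx] for values in data_list]
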
@@ -968,6 +969,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
                     if not static_groups:
                         if (i1 + i) % groupsize == 0:
                             self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True)
+                            scale.append(self.quantizer.scale)
                             zero.append(self.quantizer.zero)
                     else:
                         idx = i1 + i
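
The added scale.append(...) keeps the per-group scales in lock-step with the per-group zero points: every time find_params runs on a new group of groupsize columns, both values need to be recorded, otherwise the two lists drift apart when they are later stacked and returned. A toy illustration of the pattern (the min/max quantizer and tensor sizes are invented for the example, not the repo's Quantizer):

import torch

def find_params(block, bits=4):
    # asymmetric min/max quantization parameters for one group of columns
    wmin, wmax = block.min(), block.max()
    scale = (wmax - wmin).clamp(min=1e-8) / (2**bits - 1)
    zero = torch.round(-wmin / scale)
    return scale, zero

W = torch.randn(8, 256)
groupsize = 64
scale, zero = [], []
for i in range(0, W.shape[1], groupsize):
    s, z = find_params(W[:, i : i + groupsize])
    scale.append(s)  # the counterpart of the line this commit adds
    zero.append(z)   # must stay paired with scale.append above
assert len(scale) == len(zero) == W.shape[1] // groupsize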