@@ -718,10 +718,12 @@ def tmp(_, inp, out):
                     for n, p in sub_layer.named_parameters():
                         param_name = full_layer_name + "." + n
                         if n == "weight":
-                            set_module_tensor_to_device(self.model, param_name, self.device, Q)
+                            set_module_tensor_to_device(self.model, param_name, self.device, Q, dtype=Q.dtype)
                         else:
                             value = load_value(self.model, param_name, model_path)
-                            set_module_tensor_to_device(self.model, param_name, self.device, value)
+                            set_module_tensor_to_device(
+                                self.model, param_name, self.device, value, dtype=value.dtype
+                            )
                     # sub_layer.weight.data = Q
                     torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
                     clean_module_weight(sub_layer)
@@ -745,6 +747,8 @@ def tmp(_, inp, out):
             for j in range(len(self.dataloader)):
                 cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
                 cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+                # breakpoint()
+                # transformer_block = transformer_block.to(getattr(torch, self.model.config.torch_dtype))
                 out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
                 out = self.track_hidden_states(out)
                 outs.append(out)
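# Note (illustrative, not part of the commit): the change above threads the value's own
# dtype into set_module_tensor_to_device, so a parameter materialized from the meta
# device keeps the precision of the tensor being loaded instead of being cast to the
# module's existing dtype. A minimal sketch of that behavior, assuming the project's
# helper mirrors accelerate.utils.set_module_tensor_to_device; SmallNet and the fp16
# tensor below are made up for the example.
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device


class SmallNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


with init_empty_weights():
    model = SmallNet()  # parameters live on the "meta" device with the default fp32 dtype

# Stand-in for a quantized weight produced in half precision.
Q = torch.randn(4, 4, dtype=torch.float16)

# Without dtype=..., accelerate casts the value to the existing parameter's dtype (fp32 here);
# passing dtype=Q.dtype keeps the tensor's own precision when it is materialized.
set_module_tensor_to_device(model, "linear.weight", "cpu", Q, dtype=Q.dtype)

print(model.linear.weight.dtype)  # torch.float16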