File tree 1 file changed +3
-2
lines changed
1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -1072,14 +1072,15 @@ class LoraMixin:
1072
1072
LORA_B_PARAM_NAME = "lora_B"
1073
1073
1074
1074
def init_lora (self , lspec : PTLoraSpec ):
1075
+ self ._lspec = lspec
1075
1076
default_lora_dtype = torch .bfloat16
1076
1077
out_features , in_features = lspec .orig_weight_shape
1077
1078
rank = lspec .lora_rank
1078
1079
if rank > out_features or rank > in_features :
1079
1080
msg = f"Specified LoRA rank={ rank } cannot exceed any dimension of the weight tensor"
1080
1081
raise nncf .ValidationError (msg )
1081
- self ._lora_A = torch .nn .Parameter (torch .ones ((rank , in_features ), dtype = default_lora_dtype ))
1082
- self ._lora_B = torch .nn .Parameter (torch .zeros ((out_features , rank ), dtype = default_lora_dtype ))
1082
+ self .lora_A = torch .nn .Parameter (torch .ones ((rank , in_features ), dtype = default_lora_dtype ))
1083
+ self .lora_B = torch .nn .Parameter (torch .zeros ((out_features , rank ), dtype = default_lora_dtype ))
1083
1084
1084
1085
def enable_gradients (self ):
1085
1086
self .lora_A .requires_grad = True
You can’t perform that action at this time.
0 commit comments