
Commit 98b14c2

[QAT Lora 5/N] Fixes for loading/saving compression checkpoint (#3341)
### Changes

Added a test for loading/saving a compression checkpoint and fixed the errors it revealed: a device mismatch between modules and an invalid LoRA rank passing validation.

### Reason for changes

In a multi-device setup (e.g. when the HF model is loaded with `device_map='auto'`), a compression module and its corresponding layer may end up on different devices, so their devices need to be aligned.

### Related tickets

154907

### Tests

- test_checkpoint_loading
- test_invalid_lora_rank
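For reference, a minimal sketch of the save/load round-trip exercised by test_checkpoint_loading. The base-model loading and `compress_weights` call are omitted: `model` is assumed to be an HF model already compressed with `compression_format=CompressionFormat.FQ_LORA`, and `example_input` the dict of tokenized inputs used for tracing; the checkpoint path is a placeholder.

```python
import torch

from nncf.torch import load_from_config

# Save: persist the NNCF state dict together with the NNCF config.
torch.save(
    {
        "nncf_state_dict": model.nncf.state_dict(),
        "nncf_config": model.nncf.get_config(),
    },
    "nncf_ckpt.pth",
)

# Load: re-create the compression modules on a freshly loaded base model,
# then restore their weights (LoRA adapters, quantizer parameters, etc.).
nncf_ckpt = torch.load("nncf_ckpt.pth", weights_only=False)
model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=example_input)
model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])
```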
1 parent 97a3a3a commit 98b14c2

6 files changed: +130 -16 lines changed

nncf/quantization/algorithms/weight_compression/algorithm.py (+22 -3)

```diff
@@ -167,9 +167,27 @@ def check_user_compression_configuration(
         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
         raise nncf.ValidationError(msg)
 
-    if subset_size <= 0:
-        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
-        raise nncf.ValidationError(msg)
+    values_to_check = [subset_size]
+    ranks = []
+    if advanced_parameters:
+        values_to_check.extend(
+            [
+                advanced_parameters.awq_params.subset_size,
+                advanced_parameters.scale_estimation_params.subset_size,
+                advanced_parameters.gptq_params.subset_size,
+                advanced_parameters.lora_correction_params.subset_size,
+            ]
+        )
+        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
+    for size in values_to_check:
+        if size <= 0:
+            msg = f"The subset_size value should be positive, but subset_size={size} is given."
+            raise nncf.ValidationError(msg)
+
+    for rank in ranks:
+        if rank <= 0:
+            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
+            raise nncf.ValidationError(msg)
 
     if (
         ratio
@@ -663,6 +681,7 @@ def apply(
             zero_points,
             lora_correction_algo,
             self._compression_format,
+            self._advanced_parameters,
        )
 
        self._backend_entity.dump_parameters(
```
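As a usage-level illustration of the new validation, here is a sketch modeled on test_invalid_lora_rank (added below), but with a hypothetical non-positive rank that the configuration check now rejects up front; the toy single-layer model is a stand-in, not part of the NNCF test suite.

```python
import pytest
import torch

from nncf.data.dataset import Dataset
from nncf.errors import ValidationError
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.quantize_model import compress_weights

# A toy single-layer model stands in for a real network here.
model = torch.nn.Sequential(torch.nn.Linear(2, 2, bias=False))

# lora_adapter_rank must be positive; the new checks raise before compression starts.
with pytest.raises(ValidationError):
    compress_weights(
        model,
        mode=CompressWeightsMode.INT4_ASYM,
        group_size=2,
        all_layers=True,
        dataset=Dataset([torch.ones(2, 2)]),
        compression_format=CompressionFormat.FQ_LORA,
        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=0),
    )
```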

nncf/torch/quantization/layers.py (+23 -2)

```diff
@@ -768,6 +768,10 @@ def signed(self, signed: bool):
         self.set_levels()
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return symmetric_quantize(
             x, self.levels, self.level_low, self.level_high, self.scale, self.eps, skip=execute_traced_op_as_identity
         )
@@ -955,6 +959,10 @@ def set_levels(self):
         self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return asymmetric_quantize(
             x,
             self.levels,
@@ -1067,9 +1075,14 @@ class LoraMixin:
 
     def init_lora(self, lspec: PTLoraSpec):
         self._lspec = lspec
+        default_lora_dtype = torch.bfloat16
         out_features, in_features = lspec.orig_weight_shape
-        self.lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16))
-        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16))
+        rank = lspec.lora_rank
+        if rank > out_features or rank > in_features:
+            msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
+            raise nncf.ValidationError(msg)
+        self.lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))
 
     def enable_gradients(self):
         self.lora_A.requires_grad = True
@@ -1097,6 +1110,10 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return asymmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
@@ -1142,6 +1159,10 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return symmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
```
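The device-alignment change above boils down to moving the quantizer onto the activation's device at call time. Below is a standalone toy illustration of that pattern, not NNCF code: the module name and its contents are hypothetical.

```python
import torch


class ScaleOnInputDevice(torch.nn.Module):
    """Toy module whose parameters follow the input's device, like the quantizers above."""

    def __init__(self):
        super().__init__()
        # May end up on a different device than the activations after a checkpoint load.
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Align this module's parameters with the incoming tensor's device
        # (the NNCF quantizers do this under DisableTorchFunction to keep it out of tracing).
        self.to(x.device)
        return x * self.scale


module = ScaleOnInputDevice()   # parameters start on CPU
x = torch.randn(4)
if torch.cuda.is_available():
    x = x.cuda()                # e.g. activations produced on GPU with device_map="auto"
print(module(x).device)         # parameters were moved to match x.device
```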

tests/torch/ptq/test_fq_lora.py (+73)

```diff
@@ -15,9 +15,18 @@
 from transformers import AutoTokenizer
 
 import nncf
+from nncf.data.dataset import Dataset
+from nncf.errors import ValidationError
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.scopes import IgnoredScope
+from nncf.torch import load_from_config
 from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
 from nncf.torch.quantization.layers import LoraMixin
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
+from tests.torch.test_models.synthetic import LinearModel
 
 
 @pytest.mark.parametrize(
@@ -80,3 +89,67 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable
 
     assert first_loss > 8
     assert float(loss) < 1
+
+
+def test_checkpoint_loading(tmp_path):
+    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA test case for CPU only setups.")
+    device = "cuda"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    example_input = tokenizer("dummy", return_tensors="pt").to(device)
+    except_lm_head_and_5th_vproj = (
+        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
+    )
+    model = compress_weights(
+        model,
+        group_size=32,
+        mode=CompressWeightsMode.INT4_ASYM,
+        backup_mode=CompressWeightsMode.INT8_ASYM,
+        dataset=Dataset([dict(example_input)]),
+        compression_format=CompressionFormat.FQ_LORA,
+        ignored_scope=IgnoredScope(patterns=[except_lm_head_and_5th_vproj]),
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2),
+    )
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+
+    # save checkpoint
+    ckpt_path = tmp_path / "nncf_ckpt.pth"
+    torch.save(
+        {
+            "nncf_state_dict": model.nncf.state_dict(),
+            "nncf_config": model.nncf.get_config(),
+        },
+        ckpt_path,
+    )
+    del model
+
+    # load checkpoint
+    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input))
+    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])
+
+    actual_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0],
+        skip_special_tokens=True,
+    )
+    assert actual_output == ref_output
+
+
+def test_invalid_lora_rank():
+    too_big_rank = 4
+    model = LinearModel(torch.ones(2, 2))
+    with pytest.raises(ValidationError):
+        compress_weights(
+            model,
+            mode=CompressWeightsMode.INT4_ASYM,
+            group_size=2,
+            all_layers=True,
+            dataset=Dataset([torch.ones(2, 2)]),
+            compression_format=CompressionFormat.FQ_LORA,
+            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank),
+        )
```

tests/torch/ptq/test_weights_compression.py (+1 -10)

```diff
@@ -39,6 +39,7 @@
 from nncf.torch.quantization.quantize_functions import unpack_int4
 from nncf.torch.quantization.quantize_functions import unpack_uint4
 from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
+from tests.torch.test_models.synthetic import LinearModel
 from tests.torch.test_models.synthetic import ShortTransformer
 from tests.torch.test_tensor import cast_to
 
@@ -82,16 +83,6 @@ def forward(self, input):
         return input @ self.w
 
 
-class LinearModel(torch.nn.Module):
-    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
-        super().__init__()
-        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
-        self.linear.weight = torch.nn.Parameter(weight)
-
-    def forward(self, input):
-        return self.linear(input)
-
-
 class AWQActLinearModel(nn.Module):
     def __init__(self, with_multiply=False, n_layers=8):
         super().__init__()
```

tests/torch/quantization/test_layers.py (+1 -1)

```diff
@@ -36,7 +36,7 @@ def test_quantizer_layers_accepts_return_type(registred):
     )
     if mode in [QuantizationMode.ASYMMETRIC_LORA, QuantizationMode.SYMMETRIC_LORA]:
         shape = actual_input.unsqueeze(dim=0).shape
-        lora_spec = PTLoraSpec(2, shape, shape)
+        lora_spec = PTLoraSpec(0, shape, shape)
         quantizer = quantizer_cls(quantizer_spec, lora_spec)
     else:
         quantizer = quantizer_cls(quantizer_spec)
```

tests/torch/test_models/synthetic.py (+10)

```diff
@@ -660,6 +660,16 @@ def forward(self, input_ids):
         return res
 
 
+class LinearModel(torch.nn.Module):
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
+        self.linear.weight = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
 class YOLO11N_SDPABlock(torch.nn.Module):
     INPUT_SIZE = (1, 2, 4)
 
```