
Commit 69df6d8

Fixes for loading/saving compression checkpoint
1 parent 5f4378e

File tree

  nncf/quantization/algorithms/weight_compression/algorithm.py
  nncf/torch/quantization/layers.py
  tests/torch/ptq/test_fq_lora.py
  tests/torch/ptq/test_weights_compression.py
  tests/torch/test_models/synthetic.py

5 files changed: +128, -16 lines

nncf/quantization/algorithms/weight_compression/algorithm.py (+22, -3)

@@ -160,9 +160,27 @@ def check_user_compression_configuration(
         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
         raise nncf.ValidationError(msg)
 
-    if subset_size <= 0:
-        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
-        raise nncf.ValidationError(msg)
+    values_to_check = [subset_size]
+    ranks = []
+    if advanced_parameters:
+        values_to_check.extend(
+            [
+                advanced_parameters.awq_params.subset_size,
+                advanced_parameters.scale_estimation_params.subset_size,
+                advanced_parameters.gptq_params.subset_size,
+                advanced_parameters.lora_correction_params.subset_size,
+            ]
+        )
+        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
+    for size in values_to_check:
+        if size <= 0:
+            msg = f"The subset_size value should be positive, but subset_size={size} is given."
+            raise nncf.ValidationError(msg)
+
+    for rank in ranks:
+        if rank <= 0:
+            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
+            raise nncf.ValidationError(msg)
 
     if (
         ratio
@@ -656,6 +674,7 @@ def apply(
             zero_points,
             lora_correction_algo,
             self._compression_format,
+            self._advanced_parameters,
         )
 
         self._backend_entity.dump_parameters(
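
Note: the broadened validation fails fast on a bad configuration instead of erroring deep inside the algorithm. A minimal sketch of the new behavior (hypothetical toy model; assumes, as in the tests below, that compress_weights validates the user configuration before touching the model):

    import torch

    import nncf
    from nncf import CompressWeightsMode
    from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

    # Hypothetical toy model, for illustration only.
    model = torch.nn.Linear(256, 256)

    try:
        nncf.compress_weights(
            model,
            mode=CompressWeightsMode.INT4_ASYM,
            # Non-positive LoRA adapter ranks are now rejected up front.
            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=0),
        )
    except nncf.ValidationError as err:
        print(err)  # The lora adapter rank should be positive, but rank=0 is given.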

nncf/torch/quantization/layers.py (+19, -3)

@@ -768,6 +768,9 @@ def signed(self, signed: bool):
         self.set_levels()
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return symmetric_quantize(
             x, self.levels, self.level_low, self.level_high, self.scale, self.eps, skip=execute_traced_op_as_identity
         )
@@ -955,6 +958,9 @@ def set_levels(self):
         self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return asymmetric_quantize(
             x,
             self.levels,
@@ -1066,10 +1072,14 @@ class LoraMixin:
     LORA_B_PARAM_NAME = "lora_B"
 
     def init_lora(self, lspec: PTLoraSpec):
-        self._lspec = lspec
+        default_lora_dtype = torch.bfloat16
         out_features, in_features = lspec.orig_weight_shape
-        self.lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16))
-        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16))
+        rank = lspec.lora_rank
+        if rank > out_features or rank > in_features:
+            msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
+            raise nncf.ValidationError(msg)
+        self._lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self._lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))
 
     def enable_gradients(self):
         self.lora_A.requires_grad = True
@@ -1097,6 +1107,9 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return asymmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
@@ -1142,6 +1155,9 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return symmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
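
Note: all four quantize() overrides gain the same guard. After a checkpoint is restored in a multi-device setup, a quantizer's tensors can sit on a different device than the incoming activation, so the module moves itself to the activation's device; DisableTorchFunction suppresses __torch_function__ dispatch, presumably so the internal self.to(...) call stays invisible to NNCF's operator tracing. A self-contained sketch of the pattern with a hypothetical plain-PyTorch module (not the NNCF classes themselves):

    import torch

    class ScaledIdentity(torch.nn.Module):
        """Hypothetical stand-in for a quantizer holding learnable parameters."""

        def __init__(self) -> None:
            super().__init__()
            self.scale = torch.nn.Parameter(torch.ones(1))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # After load_state_dict in a multi-device setup, self.scale may live
            # on a different device than x; moving the module lazily avoids a
            # cross-device RuntimeError on the first call.
            if self.scale.device != x.device:
                self.to(x.device)
            return x * self.scale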

tests/torch/ptq/test_fq_lora.py (+76)

@@ -15,9 +15,18 @@
 from transformers import AutoTokenizer
 
 import nncf
+from nncf.data.dataset import Dataset
+from nncf.errors import ValidationError
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.scopes import IgnoredScope
+from nncf.torch import load_from_config
 from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
 from nncf.torch.quantization.layers import LoraMixin
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
+from tests.torch.test_models.synthetic import LinearModel
 
 
 @pytest.mark.parametrize(
@@ -80,3 +89,70 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable
 
     assert first_loss > 8
     assert float(loss) < 1
+
+
+def test_checkpoint_loading(tmp_path):
+    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA test case for CPU only setups.")
+    device = "cuda"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    example_input = tokenizer("dummy", return_tensors="pt").to(device)
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+    except_lm_head_and_5th_vproj = (
+        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
+    )
+    model = compress_weights(
+        model,
+        group_size=32,
+        mode=CompressWeightsMode.INT4_ASYM,
+        backup_mode=CompressWeightsMode.INT8_ASYM,
+        dataset=Dataset([dict(example_input)]),
+        compression_format=CompressionFormat.FQ_LORA,
+        ignored_scope=IgnoredScope(patterns=[except_lm_head_and_5th_vproj]),
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2),
+    )
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+
+    # save checkpoint
+    ckpt_path = tmp_path / "nncf_ckpt.pth"
+    torch.save(
+        {
+            "nncf_state_dict": model.nncf.state_dict(),
+            "nncf_config": model.nncf.get_config(),
+        },
+        ckpt_path,
+    )
+    del model
+
+    # load checkpoint
+    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input))
+    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])
+
+    actual_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0],
+        skip_special_tokens=True,
+    )
+    assert actual_output == ref_output
+
+
+def test_invalid_lora_rank():
+    too_big_rank = 4
+    model = LinearModel(torch.ones(2, 2))
+    with pytest.raises(ValidationError):
+        compress_weights(
+            model,
+            mode=CompressWeightsMode.INT4_ASYM,
+            group_size=2,
+            all_layers=True,
+            dataset=Dataset([torch.ones(2, 2)]),
+            compression_format=CompressionFormat.FQ_LORA,
+            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank),
+        )
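
Note: distilled from test_checkpoint_loading above, the save/load round trip is (names as in the test; base_model stands for the freshly re-created uncompressed model):

    # Save: the compression state dict plus the config needed to rebuild it.
    torch.save(
        {"nncf_state_dict": model.nncf.state_dict(), "nncf_config": model.nncf.get_config()},
        ckpt_path,
    )

    # Load: rebuild the compression structure on a fresh model, then restore tensors.
    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
    model = load_from_config(base_model, nncf_ckpt["nncf_config"], example_input=example_input)
    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])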

tests/torch/ptq/test_weights_compression.py (+1, -10)

@@ -39,6 +39,7 @@
 from nncf.torch.quantization.quantize_functions import unpack_int4
 from nncf.torch.quantization.quantize_functions import unpack_uint4
 from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
+from tests.torch.test_models.synthetic import LinearModel
 from tests.torch.test_models.synthetic import ShortTransformer
 from tests.torch.test_tensor import cast_to
 
@@ -82,16 +83,6 @@ def forward(self, input):
         return input @ self.w
 
 
-class LinearModel(torch.nn.Module):
-    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
-        super().__init__()
-        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
-        self.linear.weight = torch.nn.Parameter(weight)
-
-    def forward(self, input):
-        return self.linear(input)
-
-
 class AWQActLinearModel(nn.Module):
     def __init__(self, with_multiply=False, n_layers=8):
         super().__init__()

tests/torch/test_models/synthetic.py (+10)

@@ -660,6 +660,16 @@ def forward(self, input_ids):
         return res
 
 
+class LinearModel(torch.nn.Module):
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
+        self.linear.weight = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
 class YOLO11N_SDPABlock(torch.nn.Module):
     INPUT_SIZE = (1, 2, 4)
