From 69df6d8764e8b4b6e9886b50d962754bf7783f68 Mon Sep 17 00:00:00 2001
From: Nikolay
Date: Fri, 14 Mar 2025 18:11:54 +0100
Subject: [PATCH 1/5] Fixes for loading/saving compression checkpoint

---
 .../weight_compression/algorithm.py          | 25 +++++-
 nncf/torch/quantization/layers.py            | 22 +++++-
 tests/torch/ptq/test_fq_lora.py              | 76 +++++++++++++++++++
 tests/torch/ptq/test_weights_compression.py  | 11 +--
 tests/torch/test_models/synthetic.py         | 10 +++
 5 files changed, 128 insertions(+), 16 deletions(-)

diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 6169f4bada3..d484a7bd4ee 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -160,9 +160,27 @@ def check_user_compression_configuration(
         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
         raise nncf.ValidationError(msg)
 
-    if subset_size <= 0:
-        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
-        raise nncf.ValidationError(msg)
+    values_to_check = [subset_size]
+    ranks = []
+    if advanced_parameters:
+        values_to_check.extend(
+            [
+                advanced_parameters.awq_params.subset_size,
+                advanced_parameters.scale_estimation_params.subset_size,
+                advanced_parameters.gptq_params.subset_size,
+                advanced_parameters.lora_correction_params.subset_size,
+            ]
+        )
+        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
+    for size in values_to_check:
+        if size <= 0:
+            msg = f"The subset_size value should be positive, but subset_size={size} is given."
+            raise nncf.ValidationError(msg)
+
+    for rank in ranks:
+        if rank <= 0:
+            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
+            raise nncf.ValidationError(msg)
 
     if (
         ratio
@@ -656,6 +674,7 @@ def apply(
             zero_points,
             lora_correction_algo,
             self._compression_format,
+            self._advanced_parameters,
         )
 
         self._backend_entity.dump_parameters(
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 79c12868d0a..35e9966614b 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -768,6 +768,9 @@ def signed(self, signed: bool):
         self.set_levels()
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return symmetric_quantize(
             x, self.levels, self.level_low, self.level_high, self.scale, self.eps, skip=execute_traced_op_as_identity
         )
@@ -955,6 +958,9 @@ def set_levels(self):
         self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return asymmetric_quantize(
             x,
             self.levels,
@@ -1066,10 +1072,14 @@ class LoraMixin:
     LORA_B_PARAM_NAME = "lora_B"
 
     def init_lora(self, lspec: PTLoraSpec):
-        self._lspec = lspec
+        default_lora_dtype = torch.bfloat16
         out_features, in_features = lspec.orig_weight_shape
-        self.lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16))
-        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16))
+        rank = lspec.lora_rank
+        if rank > out_features or rank > in_features:
+            msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
+            raise nncf.ValidationError(msg)
+        self._lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self._lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))
 
     def enable_gradients(self):
         self.lora_A.requires_grad = True
@@ -1097,6 +1107,9 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return asymmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
@@ -1142,6 +1155,9 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
         return symmetric_quantize_lora(
             x,
             self._lspec.weight_shape,
diff --git a/tests/torch/ptq/test_fq_lora.py b/tests/torch/ptq/test_fq_lora.py
index 327733c95b8..e733dc8c527 100644
--- a/tests/torch/ptq/test_fq_lora.py
+++ b/tests/torch/ptq/test_fq_lora.py
@@ -15,9 +15,18 @@
 from transformers import AutoTokenizer
 
 import nncf
+from nncf.data.dataset import Dataset
+from nncf.errors import ValidationError
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.scopes import IgnoredScope
+from nncf.torch import load_from_config
 from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
 from nncf.torch.quantization.layers import LoraMixin
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
+from tests.torch.test_models.synthetic import LinearModel
 
 
 @pytest.mark.parametrize(
@@ -80,3 +89,70 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable
 
     assert first_loss > 8
     assert float(loss) < 1
+
+
+def test_checkpoint_loading(tmp_path):
+    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA test case for CPU only setups.")
+    device = "cuda"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    example_input = tokenizer("dummy", return_tensors="pt").to(device)
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+    except_lm_head_and_5th_vproj = (
+        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
+    )
+    model = compress_weights(
compress_weights( + model, + group_size=32, + mode=CompressWeightsMode.INT4_ASYM, + backup_mode=CompressWeightsMode.INT8_ASYM, + dataset=Dataset([dict(example_input)]), + compression_format=CompressionFormat.FQ_LORA, + ignored_scope=IgnoredScope(patterns=[except_lm_head_and_5th_vproj]), + advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2), + ) + ref_output = tokenizer.decode( + model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True + ) + + # save checkpoint + ckpt_path = tmp_path / "nncf_ckpt.pth" + torch.save( + { + "nncf_state_dict": model.nncf.state_dict(), + "nncf_config": model.nncf.get_config(), + }, + ckpt_path, + ) + del model + + # load checkpoint + nncf_ckpt = torch.load(ckpt_path, weights_only=False) + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") + model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input)) + model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"]) + + actual_output = tokenizer.decode( + model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], + skip_special_tokens=True, + ) + assert actual_output == ref_output + + +def test_invalid_lora_rank(): + too_big_rank = 4 + model = LinearModel(torch.ones(2, 2)) + with pytest.raises(ValidationError): + compress_weights( + model, + mode=CompressWeightsMode.INT4_ASYM, + group_size=2, + all_layers=True, + dataset=Dataset([torch.ones(2, 2)]), + compression_format=CompressionFormat.FQ_LORA, + advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank), + ) diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index b71c4179ca3..7a9455c7be4 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -39,6 +39,7 @@ from nncf.torch.quantization.quantize_functions import unpack_int4 from nncf.torch.quantization.quantize_functions import unpack_uint4 from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression +from tests.torch.test_models.synthetic import LinearModel from tests.torch.test_models.synthetic import ShortTransformer from tests.torch.test_tensor import cast_to @@ -82,16 +83,6 @@ def forward(self, input): return input @ self.w -class LinearModel(torch.nn.Module): - def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)): - super().__init__() - self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False) - self.linear.weight = torch.nn.Parameter(weight) - - def forward(self, input): - return self.linear(input) - - class AWQActLinearModel(nn.Module): def __init__(self, with_multiply=False, n_layers=8): super().__init__() diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py index e64a234d1c6..c6b9d9678c5 100644 --- a/tests/torch/test_models/synthetic.py +++ b/tests/torch/test_models/synthetic.py @@ -660,6 +660,16 @@ def forward(self, input_ids): return res +class LinearModel(torch.nn.Module): + def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)): + super().__init__() + self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False) + self.linear.weight = torch.nn.Parameter(weight) + + def forward(self, input): + return self.linear(input) + + class YOLO11N_SDPABlock(torch.nn.Module): INPUT_SIZE = (1, 2, 4) From a15ca14dac07e8dac816a819b22262adf448d42f Mon Sep 17 00:00:00 
From: Nikolay
Date: Fri, 14 Mar 2025 19:45:02 +0100
Subject: [PATCH 2/5] typo

---
 nncf/torch/quantization/layers.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 35e9966614b..abccb8062fa 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -1072,14 +1072,15 @@ class LoraMixin:
     LORA_B_PARAM_NAME = "lora_B"
 
     def init_lora(self, lspec: PTLoraSpec):
+        self._lspec = lspec
         default_lora_dtype = torch.bfloat16
         out_features, in_features = lspec.orig_weight_shape
         rank = lspec.lora_rank
         if rank > out_features or rank > in_features:
             msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
             raise nncf.ValidationError(msg)
-        self._lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
-        self._lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))
+        self.lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))
 
     def enable_gradients(self):
         self.lora_A.requires_grad = True

From 4042607cbea4fdf16e343624c78af7a4d612c48b Mon Sep 17 00:00:00 2001
From: Nikolay
Date: Fri, 14 Mar 2025 23:38:00 +0100
Subject: [PATCH 3/5] fixed lora rank test

---
 tests/torch/quantization/test_layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/torch/quantization/test_layers.py b/tests/torch/quantization/test_layers.py
index 97cbb4c41ab..e42a2e88c7b 100644
--- a/tests/torch/quantization/test_layers.py
+++ b/tests/torch/quantization/test_layers.py
@@ -36,7 +36,7 @@ def test_quantizer_layers_accepts_return_type(registred):
     )
     if mode in [QuantizationMode.ASYMMETRIC_LORA, QuantizationMode.SYMMETRIC_LORA]:
         shape = actual_input.unsqueeze(dim=0).shape
-        lora_spec = PTLoraSpec(2, shape, shape)
+        lora_spec = PTLoraSpec(0, shape, shape)
         quantizer = quantizer_cls(quantizer_spec, lora_spec)
     else:
         quantizer = quantizer_cls(quantizer_spec)

From 9360c0f0652a53a5d6b36d5c387e82835b9da105 Mon Sep 17 00:00:00 2001
From: Nikolay
Date: Mon, 17 Mar 2025 21:59:27 +0100
Subject: [PATCH 4/5] removed unused line

---
 tests/torch/ptq/test_fq_lora.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/torch/ptq/test_fq_lora.py b/tests/torch/ptq/test_fq_lora.py
index e733dc8c527..a6c667b6823 100644
--- a/tests/torch/ptq/test_fq_lora.py
+++ b/tests/torch/ptq/test_fq_lora.py
@@ -99,9 +99,6 @@ def test_checkpoint_loading(tmp_path):
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     example_input = tokenizer("dummy", return_tensors="pt").to(device)
-    ref_output = tokenizer.decode(
-        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
-    )
     except_lm_head_and_5th_vproj = (
         r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
     )

From 2c071ed3315b0f3651799b8f5931005176e716bf Mon Sep 17 00:00:00 2001
From: Nikolay
Date: Wed, 19 Mar 2025 17:10:06 +0100
Subject: [PATCH 5/5] added todo to fix wa for new tracing

---
 nncf/torch/quantization/layers.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index abccb8062fa..46629bc380e 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -768,6 +768,7 @@ def signed(self, signed: bool):
         self.set_levels()
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
         with DisableTorchFunction():
             # in multi-device case after loading nncf checkpoint, quantizers have a different device.
             self.to(x.device)
@@ -958,6 +959,7 @@ def set_levels(self):
         self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
         with DisableTorchFunction():
             # in multi-device case after loading nncf checkpoint, quantizers have a different device.
             self.to(x.device)
@@ -1108,6 +1110,7 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
         with DisableTorchFunction():
             # in multi-device case after loading nncf checkpoint, quantizers have a different device.
             self.to(x.device)
@@ -1156,6 +1159,7 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
         self.init_lora(lspec)
 
     def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
         with DisableTorchFunction():
            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
             self.to(x.device)
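
Note for readers (not part of the patches): the checkpoint round-trip exercised by the new test_checkpoint_loading test in PATCH 1/5 boils down to the sketch below. It assumes a PyTorch model already compressed with compress_weights(..., compression_format=CompressionFormat.FQ_LORA); the names base_model, example_input, and ckpt_path are illustrative placeholders rather than APIs introduced by these patches.

    import torch
    from nncf.torch import load_from_config

    # Save: persist the NNCF config together with the compressed model's state dict.
    torch.save(
        {
            "nncf_state_dict": model.nncf.state_dict(),
            "nncf_config": model.nncf.get_config(),
        },
        ckpt_path,
    )

    # Load: re-wrap a fresh base model from the saved config, then restore the compressed weights.
    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
    model = load_from_config(base_model, nncf_ckpt["nncf_config"], example_input=example_input)
    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])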