Commit d840360

no wa for device in fq, test for lora and chkp load
1 parent 6899d77 commit d840360

File tree

6 files changed (+125, -15 lines)

nncf/quantization/algorithms/weight_compression/algorithm.py

+16, -3
@@ -160,9 +160,21 @@ def check_user_compression_configuration(
         msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
         raise nncf.ValidationError(msg)

-    if subset_size <= 0:
-        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
-        raise nncf.ValidationError(msg)
+    for size in [
+        subset_size,
+        advanced_parameters.awq_params.subset_size,
+        advanced_parameters.scale_estimation_params.subset_size,
+        advanced_parameters.gptq_params.subset_size,
+        advanced_parameters.lora_correction_params.subset_size,
+    ]:
+        if size <= 0:
+            msg = f"The subset_size value should be positive, but subset_size={size} is given."
+            raise nncf.ValidationError(msg)
+
+    for rank in [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]:
+        if rank <= 0:
+            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
+            raise nncf.ValidationError(msg)

     if (
         ratio

@@ -656,6 +668,7 @@ def apply(
             zero_points,
             lora_correction_algo,
             self._compression_format,
+            self._advanced_parameters,
         )

         self._backend_entity.dump_parameters(
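
Net effect of the validation change: check_user_compression_configuration now rejects a non-positive subset_size wherever it appears (top level, AWQ, scale estimation, GPTQ, LoRA correction) and a non-positive rank for either LoRA knob, so a bad value fails at the entry point instead of surfacing mid-compression. A minimal sketch of a configuration the new loop rejects; the nested parameter class name is an assumption from nncf.quantization.advanced_parameters, not part of this diff:

from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters  # assumed class name

# A zero nested subset_size now fails the same check as the top-level one:
params = AdvancedCompressionParameters(
    scale_estimation_params=AdvancedScaleEstimationParameters(subset_size=0),
)
# nncf.compress_weights(model, ..., advanced_parameters=params) is expected to raise
# nncf.ValidationError: "The subset_size value should be positive, but subset_size=0 is given."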

nncf/torch/model_transformer.py

+15
@@ -29,6 +29,8 @@
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
+from nncf.torch.model_graph_manager import get_const_data
+from nncf.torch.model_graph_manager import get_const_node
 from nncf.torch.model_graph_manager import update_fused_bias
 from nncf.torch.module_operations import UpdateWeight
 from nncf.torch.nncf_network import NNCFNetwork

@@ -74,6 +76,19 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork:
                 model = transformation_fn(model, transformations)

         if requires_graph_rebuild:
+            graph = model.nncf.get_original_graph()
+            for command in transformation_layout.transformations:
+                compression_module = command.fn
+                if isinstance(compression_module, nn.Module):
+                    target_point = command.target_points[0]
+                    node_with_weight = graph.get_node_by_name(target_point.target_node_name)
+                    weight_node = get_const_node(node_with_weight, target_point.input_port_id, graph)
+                    if weight_node is None:
+                        weight_node = node_with_weight  # Decompression in DQ compression format is applied to const.
+                    const_data = get_const_data(weight_node, model)
+                    # Compression module and the corresponding layer may have a different device in multi-device setup
+                    # (e.g. when HF model was loaded with device_map='auto'). Need to align devices.
+                    compression_module.to(const_data.device)
             model.nncf.rebuild_graph()

         return model
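
Why this matters: when an HF model is loaded with device_map='auto', its layers can be sharded across devices, while a freshly constructed compression module lands on the default device; the first forward through the pair would then fail with a cross-device error. A toy sketch of the alignment being performed (plain torch, not NNCF internals):

import torch
from torch import nn

# The wrapped layer may sit on a non-default device in a sharded HF model.
layer = nn.Linear(4, 4).to("cuda:1" if torch.cuda.device_count() > 1 else "cpu")
adapter = nn.Linear(4, 4)  # a new module is created on the default device
adapter.to(layer.weight.device)  # same move as compression_module.to(const_data.device)
assert adapter.weight.device == layer.weight.device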

nncf/torch/quantization/layers.py

+7, -2
@@ -1067,9 +1067,14 @@ class LoraMixin:

     def __init__(self, lspec: PTLoraSpec):
         self._lspec = lspec
+        default_lora_dtype = torch.bfloat16
         out_features, in_features = lspec.orig_weight_shape
-        self._lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16))
-        self._lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16))
+        rank = lspec.lora_rank
+        if rank > out_features or rank > in_features:
+            msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
+            raise nncf.ValidationError(msg)
+        self._lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self._lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))

     def enable_gradients(self):
         self._lora_A.requires_grad = True
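
The rank guard follows directly from the adapter shapes: lora_A is (rank, in_features) and lora_B is (out_features, rank), so B @ A reconstructs an (out_features, in_features) update and a rank above either dimension is never useful. A quick shape check with illustrative values (not from the commit):

import torch

out_features, in_features, rank = 8, 16, 4
lora_A = torch.ones((rank, in_features), dtype=torch.bfloat16)
lora_B = torch.zeros((out_features, rank), dtype=torch.bfloat16)
# B @ A yields a full-size weight update of shape (out_features, in_features).
assert (lora_B @ lora_A).shape == (out_features, in_features)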

tests/torch/ptq/test_fq_lora.py

+76
@@ -15,9 +15,18 @@
 from transformers import AutoTokenizer

 import nncf
+from nncf.data.dataset import Dataset
+from nncf.errors import ValidationError
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.scopes import IgnoredScope
+from nncf.torch import load_from_config
 from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
 from nncf.torch.quantization.layers import LoraMixin
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
+from tests.torch.test_models.synthetic import LinearModel


 @pytest.mark.parametrize(

@@ -80,3 +89,70 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable

     assert first_loss > 8
     assert float(loss) < 1
+
+
+def test_checkpoint_loading(tmp_path):
+    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA test case for CPU only setups.")
+    device = "cuda"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    example_input = tokenizer("dummy", return_tensors="pt").to(device)
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+    except_lm_head_and_5th_vproj = (
+        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
+    )
+    model = compress_weights(
+        model,
+        group_size=32,
+        mode=CompressWeightsMode.INT4_ASYM,
+        backup_mode=CompressWeightsMode.INT8_ASYM,
+        dataset=Dataset([dict(example_input)]),
+        compression_format=CompressionFormat.FQ_LORA,
+        ignored_scope=IgnoredScope(patterns=[except_lm_head_and_5th_vproj]),
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2),
+    )
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+
+    # save checkpoint
+    ckpt_path = tmp_path / "nncf_ckpt.pth"
+    torch.save(
+        {
+            "nncf_state_dict": model.nncf.state_dict(),
+            "nncf_config": model.nncf.get_config(),
+        },
+        ckpt_path,
+    )
+    del model
+
+    # load checkpoint
+    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input))
+    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])
+
+    actual_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0],
+        skip_special_tokens=True,
+    )
+    assert actual_output == ref_output
+
+
+def test_invalid_lora_rank():
+    too_big_rank = 4
+    model = LinearModel(torch.ones(2, 2))
+    with pytest.raises(ValidationError):
+        compress_weights(
+            model,
+            mode=CompressWeightsMode.INT4_ASYM,
+            group_size=2,
+            all_layers=True,
+            dataset=Dataset([torch.ones(2, 2)]),
+            compression_format=CompressionFormat.FQ_LORA,
+            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank),
+        )
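
test_checkpoint_loading pins down the FQ_LORA checkpointing contract: save the NNCF config together with the state dict, re-apply the compression structure to a freshly loaded model via load_from_config, and only then restore the tuned tensors. A hedged distillation of that round trip (the helper names are illustrative; the API calls are the ones the test uses):

import torch
from nncf.torch import load_from_config

def save_nncf_checkpoint(model, path):
    # Both pieces are needed: the config rebuilds the compression structure,
    # the state dict carries the tuned FQ/LoRA tensors.
    torch.save({"nncf_state_dict": model.nncf.state_dict(), "nncf_config": model.nncf.get_config()}, path)

def restore_nncf_checkpoint(fresh_model, path, example_input):
    ckpt = torch.load(path, weights_only=False)
    model = load_from_config(fresh_model, ckpt["nncf_config"], example_input=example_input)
    model.nncf.load_state_dict(ckpt["nncf_state_dict"])
    return model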

tests/torch/ptq/test_weights_compression.py

+1, -10
@@ -39,6 +39,7 @@
 from nncf.torch.quantization.quantize_functions import unpack_int4
 from nncf.torch.quantization.quantize_functions import unpack_uint4
 from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
+from tests.torch.test_models.synthetic import LinearModel
 from tests.torch.test_models.synthetic import ShortTransformer
 from tests.torch.test_tensor import cast_to

@@ -82,16 +83,6 @@ def forward(self, input):
         return input @ self.w


-class LinearModel(torch.nn.Module):
-    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
-        super().__init__()
-        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
-        self.linear.weight = torch.nn.Parameter(weight)
-
-    def forward(self, input):
-        return self.linear(input)
-
-
 class AWQActLinearModel(nn.Module):
     def __init__(self, with_multiply=False, n_layers=8):
         super().__init__()

tests/torch/test_models/synthetic.py

+10
@@ -660,6 +660,16 @@ def forward(self, input_ids):
         return res


+class LinearModel(torch.nn.Module):
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
+        self.linear.weight = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return self.linear(input)
+
+
 class YOLO11N_SDPABlock(torch.nn.Module):
     INPUT_SIZE = (1, 2, 4)
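
LinearModel now lives in the shared synthetic-model zoo, so test_weights_compression.py and the new FQ-LoRA tests import a single definition. Minimal usage, as in test_invalid_lora_rank:

import torch
from tests.torch.test_models.synthetic import LinearModel

model = LinearModel(torch.ones(2, 2))
out = model(torch.ones(1, 2))  # a single bias-free Linear
assert out.shape == (1, 2)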
