[QAT Lora 5/N] Fixes for loading/saving compression checkpoint #3341

Open · wants to merge 6 commits into base: develop
25 changes: 22 additions & 3 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -160,9 +160,27 @@ def check_user_compression_configuration(
msg = f"The ratio should be between 0 and 1, but ratio={ratio} is specified."
raise nncf.ValidationError(msg)

-    if subset_size <= 0:
-        msg = f"The subset_size value should be positive, but subset_size={subset_size} is given."
-        raise nncf.ValidationError(msg)
+    values_to_check = [subset_size]
+    ranks = []
+    if advanced_parameters:
+        values_to_check.extend(
+            [
+                advanced_parameters.awq_params.subset_size,
+                advanced_parameters.scale_estimation_params.subset_size,
+                advanced_parameters.gptq_params.subset_size,
+                advanced_parameters.lora_correction_params.subset_size,
+            ]
+        )
+        ranks = [advanced_parameters.lora_adapter_rank, advanced_parameters.lora_correction_params.adapter_rank]
+    for size in values_to_check:
+        if size <= 0:
+            msg = f"The subset_size value should be positive, but subset_size={size} is given."
+            raise nncf.ValidationError(msg)
+
+    for rank in ranks:
+        if rank <= 0:
+            msg = f"The lora adapter rank should be positive, but rank={rank} is given."
+            raise nncf.ValidationError(msg)

    if (
        ratio

@@ -656,6 +674,7 @@ def apply(
                zero_points,
                lora_correction_algo,
                self._compression_format,
+                self._advanced_parameters,
            )

        self._backend_entity.dump_parameters(
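A minimal sketch of the new fail-fast behavior (standalone and illustrative; the model and parameter values are hypothetical, not part of this PR):

import torch
import nncf
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

# A non-positive adapter rank is now rejected by
# check_user_compression_configuration before any compression work starts.
model = torch.nn.Linear(16, 16)
try:
    nncf.compress_weights(
        model,
        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=0),
    )
except nncf.ValidationError as err:
    print(err)  # The lora adapter rank should be positive, but rank=0 is given.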
25 changes: 23 additions & 2 deletions nncf/torch/quantization/layers.py
@@ -768,6 +768,10 @@ def signed(self, signed: bool):
        self.set_levels()

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
        return symmetric_quantize(
            x, self.levels, self.level_low, self.level_high, self.scale, self.eps, skip=execute_traced_op_as_identity
        )
@@ -955,6 +959,10 @@ def set_levels(self):
        self.level_low, self.level_high = calculate_asymmetric_level_ranges(self.num_bits - scaled_num_bits)

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
        return asymmetric_quantize(
            x,
            self.levels,

Review thread on `with DisableTorchFunction():`
Collaborator: For the new tracing this does not work; there is no need to add this hack.
Contributor (author): Currently it is needed to pass the graph test for torch2.

Review thread on `self.to(x.device)`
Collaborator: This is only a workaround and needs to be reworked before the release; please create a ticket and add a TODO comment.
Contributor (author): Added a TODO.
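For context on the workaround above, a minimal standalone sketch of the failure mode it avoids (illustrative class, not NNCF code):

import torch

class TinyQuantizer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))  # created on CPU

    def quantize(self, x):
        # Without this move, x / self.scale raises a device-mismatch RuntimeError
        # when x lives on CUDA but the state dict was restored on a CPU module.
        self.to(x.device)
        return torch.round(x / self.scale) * self.scale

if torch.cuda.is_available():
    q = TinyQuantizer()
    y = q.quantize(torch.randn(4, device="cuda"))  # works: module moved to CUDA first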
@@ -1067,9 +1075,14 @@ class LoraMixin:

    def init_lora(self, lspec: PTLoraSpec):
        self._lspec = lspec
+        default_lora_dtype = torch.bfloat16
        out_features, in_features = lspec.orig_weight_shape
-        self.lora_A = torch.nn.Parameter(torch.ones((lspec.lora_rank, in_features), dtype=torch.bfloat16))
-        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, lspec.lora_rank), dtype=torch.bfloat16))
+        rank = lspec.lora_rank
+        if rank > out_features or rank > in_features:
+            msg = f"Specified LoRA rank={rank} cannot exceed any dimension of the weight tensor"
+            raise nncf.ValidationError(msg)
+        self.lora_A = torch.nn.Parameter(torch.ones((rank, in_features), dtype=default_lora_dtype))
+        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, rank), dtype=default_lora_dtype))

    def enable_gradients(self):
        self.lora_A.requires_grad = True
@@ -1097,6 +1110,10 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
        self.init_lora(lspec)

    def quantize(self, x: torch.Tensor, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
        return asymmetric_quantize_lora(
            x,
            self._lspec.weight_shape,
@@ -1142,6 +1159,10 @@ def __init__(self, qspec: PTQuantizerSpec, lspec: PTLoraSpec):
        self.init_lora(lspec)

    def quantize(self, x, execute_traced_op_as_identity: bool = False):
+        # TODO: (dokuchaev) remove within new tracing (ticket-163869)
+        with DisableTorchFunction():
+            # in multi-device case after loading nncf checkpoint, quantizers have a different device.
+            self.to(x.device)
        return symmetric_quantize_lora(
            x,
            self._lspec.weight_shape,
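A standalone sketch of the adapter shapes behind the new rank check (illustrative numbers, not the NNCF implementation):

import torch

# For a weight W of shape (out_features, in_features), init_lora creates
# A: (rank, in_features) filled with ones and B: (out_features, rank) filled
# with zeros, so B @ A matches W's shape and is zero at initialization.
out_features, in_features, rank = 8, 16, 4
A = torch.ones(rank, in_features, dtype=torch.bfloat16)
B = torch.zeros(out_features, rank, dtype=torch.bfloat16)
assert (B @ A).shape == (out_features, in_features)
assert torch.count_nonzero(B @ A) == 0  # adapter starts as a no-op correction

# Since rank(B @ A) <= min(out_features, in_features), a rank above either
# dimension adds parameters without adding expressive power; init_lora now
# rejects such configurations with a ValidationError.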
73 changes: 73 additions & 0 deletions tests/torch/ptq/test_fq_lora.py
@@ -15,9 +15,18 @@
from transformers import AutoTokenizer

import nncf
from nncf.data.dataset import Dataset
from nncf.errors import ValidationError
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.quantize_model import compress_weights
from nncf.scopes import IgnoredScope
from nncf.torch import load_from_config
from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
from nncf.torch.quantization.layers import LoraMixin
from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
from tests.torch.test_models.synthetic import LinearModel


@pytest.mark.parametrize(
@@ -80,3 +89,67 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable

    assert first_loss > 8
    assert float(loss) < 1


def test_checkpoint_loading(tmp_path):
    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
    if not torch.cuda.is_available():
        pytest.skip("Skipping CUDA test case for CPU only setups.")
    device = "cuda"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    example_input = tokenizer("dummy", return_tensors="pt").to(device)
    except_lm_head_and_5th_vproj = (
        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
    )
    model = compress_weights(
        model,
        group_size=32,
        mode=CompressWeightsMode.INT4_ASYM,
        backup_mode=CompressWeightsMode.INT8_ASYM,
        dataset=Dataset([dict(example_input)]),
        compression_format=CompressionFormat.FQ_LORA,
        ignored_scope=IgnoredScope(patterns=[except_lm_head_and_5th_vproj]),
        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2),
    )
    ref_output = tokenizer.decode(
        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
    )

    # save checkpoint
    ckpt_path = tmp_path / "nncf_ckpt.pth"
    torch.save(
        {
            "nncf_state_dict": model.nncf.state_dict(),
            "nncf_config": model.nncf.get_config(),
        },
        ckpt_path,
    )
    del model

    # load checkpoint
    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
    model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input))
    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])

    actual_output = tokenizer.decode(
        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0],
        skip_special_tokens=True,
    )
    assert actual_output == ref_output


def test_invalid_lora_rank():
    too_big_rank = 4
    model = LinearModel(torch.ones(2, 2))
    with pytest.raises(ValidationError):
        compress_weights(
            model,
            mode=CompressWeightsMode.INT4_ASYM,
            group_size=2,
            all_layers=True,
            dataset=Dataset([torch.ones(2, 2)]),
            compression_format=CompressionFormat.FQ_LORA,
            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank),
        )
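For context on why rank=4 is rejected for this model (a sketch of the arithmetic, not part of the test):

# LinearModel wraps a 2x2 weight, so out_features = in_features = 2. A rank-4
# adapter would need A of shape (4, 2) and B of shape (2, 4): 16 values to
# correct a 4-value weight, while rank(B @ A) can never exceed 2.
out_features = in_features = 2
too_big_rank = 4
assert too_big_rank > min(out_features, in_features)  # init_lora raises ValidationError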
11 changes: 1 addition & 10 deletions tests/torch/ptq/test_weights_compression.py
@@ -39,6 +39,7 @@
from nncf.torch.quantization.quantize_functions import unpack_int4
from nncf.torch.quantization.quantize_functions import unpack_uint4
from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
+from tests.torch.test_models.synthetic import LinearModel
from tests.torch.test_models.synthetic import ShortTransformer
from tests.torch.test_tensor import cast_to

@@ -82,16 +83,6 @@ def forward(self, input):
        return input @ self.w


-class LinearModel(torch.nn.Module):
-    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
-        super().__init__()
-        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
-        self.linear.weight = torch.nn.Parameter(weight)
-
-    def forward(self, input):
-        return self.linear(input)


class AWQActLinearModel(nn.Module):
    def __init__(self, with_multiply=False, n_layers=8):
        super().__init__()
2 changes: 1 addition & 1 deletion tests/torch/quantization/test_layers.py
@@ -36,7 +36,7 @@ def test_quantizer_layers_accepts_return_type(registred):
    )
    if mode in [QuantizationMode.ASYMMETRIC_LORA, QuantizationMode.SYMMETRIC_LORA]:
        shape = actual_input.unsqueeze(dim=0).shape
-        lora_spec = PTLoraSpec(2, shape, shape)
+        lora_spec = PTLoraSpec(0, shape, shape)
        quantizer = quantizer_cls(quantizer_spec, lora_spec)
    else:
        quantizer = quantizer_cls(quantizer_spec)
10 changes: 10 additions & 0 deletions tests/torch/test_models/synthetic.py
@@ -660,6 +660,16 @@ def forward(self, input_ids):
        return res


+class LinearModel(torch.nn.Module):
+    def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
+        super().__init__()
+        self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
+        self.linear.weight = torch.nn.Parameter(weight)
+
+    def forward(self, input):
+        return self.linear(input)


class YOLO11N_SDPABlock(torch.nn.Module):
    INPUT_SIZE = (1, 2, 4)
