 from transformers import AutoTokenizer

 import nncf
+from nncf.data.dataset import Dataset
+from nncf.errors import ValidationError
+from nncf.parameters import CompressionFormat
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
+from nncf.quantization.quantize_model import compress_weights
+from nncf.scopes import IgnoredScope
+from nncf.torch import load_from_config
 from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ
 from nncf.torch.quantization.layers import LoraMixin
 from nncf.torch.quantization.layers import SymmetricQuantizer as SQ
+from tests.torch.test_models.synthetic import LinearModel


 @pytest.mark.parametrize(
@@ -80,3 +89,70 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable

     assert first_loss > 8
     assert float(loss) < 1
+
+
+def test_checkpoint_loading(tmp_path):
+    model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA test case for CPU only setups.")
+    device = "cuda"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    example_input = tokenizer("dummy", return_tensors="pt").to(device)
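+    # Ignore every layer except the embed_out head and the query_key_value projection of the
+    # GPTNeoX layer with index 2, so only those two Linear layers are compressed and get LoRA adapters.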
+    except_lm_head_and_3rd_layer_qkv = (
+        r"^(?!.*(GPTNeoXLayer\[2\]/GPTNeoXSdpaAttention\[attention\]/Linear\[query_key_value\]/l|embed_out).*$).*$"
+    )
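+    # INT4 weight compression in the FQ_LORA format: weights are wrapped with fake-quantize
+    # operations plus trainable rank-2 LoRA adapters for quantization-aware tuning.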
+    model = compress_weights(
+        model,
+        group_size=32,
+        mode=CompressWeightsMode.INT4_ASYM,
+        backup_mode=CompressWeightsMode.INT8_ASYM,
+        dataset=Dataset([dict(example_input)]),
+        compression_format=CompressionFormat.FQ_LORA,
+        ignored_scope=IgnoredScope(patterns=[except_lm_head_and_3rd_layer_qkv]),
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=2),
+    )
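+    # Reference generation from the freshly compressed model; the restored checkpoint must reproduce it.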
+    ref_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0], skip_special_tokens=True
+    )
+
+    # save checkpoint
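+    # The NNCF config records the applied compression transformations; the state dict
+    # stores their parameters (quantizer scales, LoRA adapter weights).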
+    ckpt_path = tmp_path / "nncf_ckpt.pth"
+    torch.save(
+        {
+            "nncf_state_dict": model.nncf.state_dict(),
+            "nncf_config": model.nncf.get_config(),
+        },
+        ckpt_path,
+    )
+    del model
+
+    # load checkpoint
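+    # load_from_config re-applies the compression transformations to a freshly loaded model;
+    # load_state_dict then restores the saved quantizer and LoRA parameters.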
+    nncf_ckpt = torch.load(ckpt_path, weights_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+    model = load_from_config(model, nncf_ckpt["nncf_config"], example_input=dict(example_input))
+    model.nncf.load_state_dict(nncf_ckpt["nncf_state_dict"])
+
+    actual_output = tokenizer.decode(
+        model.generate(**example_input, do_sample=False, max_new_tokens=20)[0],
+        skip_special_tokens=True,
+    )
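+    # Greedy decoding (do_sample=False) is deterministic, so the restored model must match the reference exactly.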
+    assert actual_output == ref_output
+
+
+def test_invalid_lora_rank():
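+    # A LoRA rank of 4 exceeds the dimensions of the toy model's 2x2 weight,
+    # so compress_weights is expected to reject it with a ValidationError.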
+    too_big_rank = 4
+    model = LinearModel(torch.ones(2, 2))
+    with pytest.raises(ValidationError):
+        compress_weights(
+            model,
+            mode=CompressWeightsMode.INT4_ASYM,
+            group_size=2,
+            all_layers=True,
+            dataset=Dataset([torch.ones(2, 2)]),
+            compression_format=CompressionFormat.FQ_LORA,
+            advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=too_big_rank),
+        )