
Commit 412fd5d

initial
1 parent e9860bb commit 412fd5d

File tree

4 files changed: +328 -0 lines changed

.ci/cspell_dict.txt (+2)

@@ -41,6 +41,7 @@ baddbmm
 batchnorm
 batchnorms
 batchwise
+bfloat16
 biasadd
 bibsource
 bibtex
@@ -407,6 +408,7 @@ shufflenet
 signedness
 silu
 smol
+smolm
 softmax
 sota
 sparsifiable
README.md (+43)

@@ -0,0 +1,43 @@

# Quantization-aware tuning with absorbable LoRA adapters for improving the accuracy of 4-bit LLMs

This example demonstrates how to improve the accuracy of Large Language Models (LLMs) with 4-bit weights by
quantization-aware training with absorbable LoRA adapters.

The example includes the following steps:

- Creation of an NNCF model with extended FakeQuantize (FQ) operations on the weights of all linear layers,
  except for the embedding and lm_head layers. Each FQ includes absorbable LoRA adapters and performs fake quantization
  as `dequantize(quantize(W + B @ A))`, where W is the original weight of the linear layer and A and B are the
  LoRA adapters (see the sketch after this list). The compression part of the NNCF model is then saved in the NNCF
  checkpoint for tuning and evaluation. The initial accuracy of such a model is expected to be low, as it currently uses
  a data-free Round-To-Nearest quantization scheme. In the next step, accuracy is significantly improved by tuning
  both the quantization scales and the LoRA adapters.

- Tuning pipeline with distillation loss. The teacher model is the original bfloat16 model, while the student model
  includes FQ operations. The training dataset is based on the training split of the `wikitext-2-raw-v1` dataset and
  consists of 1024 samples of length 1024. Validation is performed at the end of each epoch using
  [WhoWhatBench](https://github.com/openvinotoolkit/openvino.genai/tree/master/llm_bench/python/who_what_benchmark).
  Tuning for 32 epochs on a single A100 card takes around 4 hours for 1.7B models, approximately 6 hours for 3B models,
  and about 12 hours for 8B models. The most significant accuracy improvement is typically achieved within the first
  1-2 epochs.

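For illustration only, the snippet below is a minimal PyTorch sketch of the `dequantize(quantize(W + B @ A))` step described above. It uses a symmetric per-channel 4-bit grid for brevity; the helper name `fq_lora_forward` and the tensor shapes are made up for this sketch, while the actual NNCF quantizers (`SymmetricLoraQuantizer`/`AsymmetricLoraQuantizer`, used in `main.py`) additionally support asymmetric and group-wise quantization.

```python
import torch


def fq_lora_forward(W: torch.Tensor, A: torch.Tensor, B: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Fake-quantize W + B @ A: snap to a 4-bit grid, then dequantize back to float."""
    w_adapted = W + B @ A  # absorb the low-rank update into the weight
    level_high = 2 ** (num_bits - 1) - 1  # 7 for symmetric INT4
    scale = w_adapted.abs().amax(dim=1, keepdim=True) / level_high  # per-output-channel scale
    q = torch.clamp(torch.round(w_adapted / scale), -level_high - 1, level_high)  # "quantize"
    return q * scale  # "dequantize": the float weight actually used by the linear layer


# Toy shapes: a 4x8 linear weight with a rank-2 adapter (A: rank x in_features, B: out_features x rank).
W, A, B = torch.randn(4, 8), torch.zeros(2, 8), torch.zeros(4, 2)
print(fq_lora_forward(W, A, B).shape)  # torch.Size([4, 8])
```

Because `B @ A` has the same shape as `W`, the tuned adapters can be folded into the quantized weight after training, which is what makes them "absorbable".
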
## Install requirements

To use this example:

- Create a separate Python* environment and activate it: `python3 -m venv nncf_env && source nncf_env/bin/activate`
- Install dependencies:

```bash
pip install -U pip
pip install -r requirements.txt
pip install ../../../../
```

## Run Example

The example is fully automated. Just run the following command in the prepared Python environment:

```bash
python main.py
```
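
Checkpoints and logs are written next to the script (see the path constants at the top of `main.py`): the most recently exported OpenVINO model goes to `output/last`, the best checkpoint so far is copied to `output/last/best`, the WhoWhatBench reference answers are cached in `output/wwb_ref.csv`, and TensorBoard scalars (loss and similarity) land in `output/tb`. Assuming TensorBoard is installed in the environment, training can be monitored with:

```bash
tensorboard --logdir output/tb
```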
main.py (+277)

@@ -0,0 +1,277 @@

# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import shutil
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Union
from weakref import WeakKeyDictionary

import torch
import torch.nn.functional as F
from datasets import load_dataset
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.openvino import OVModelForCausalLM
from torch import Tensor
from torch import nn
from torch.jit import TracerWarning
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from tqdm import trange
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from whowhatbench import TextEvaluator

import nncf
from nncf.data.dataset import Dataset
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.quantize_model import compress_weights
from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
from nncf.torch.quantization.layers import BaseWeightsDecompressor
from nncf.torch.quantization.layers import SymmetricLoraQuantizer

MODEL_ID = "HuggingFaceTB/SmolLM-1.7B-Instruct"
DEVICE = "cuda"
TORCH_DTYPE = torch.bfloat16


ROOT = Path(__file__).parent.resolve()
OUTPUT_DIR = ROOT / "output"
TENSORBOARD_DIR = OUTPUT_DIR / "tb"
LAST_DIR = OUTPUT_DIR / "last"
BEST_DIR = LAST_DIR / "best"
for path in [OUTPUT_DIR, TENSORBOARD_DIR, LAST_DIR, BEST_DIR]:
    path.mkdir(exist_ok=True, parents=True)
WWB_REF_FILE = OUTPUT_DIR / "wwb_ref.csv"


# TODO: (nlyalyus) move to Optimum-Intel (ticket 164159)
class PatchDecompressorDtype:
    """
    Patches compression modules so that bfloat16 models can be exported to OpenVINO.
    """

    def __init__(self, model):
        self.model = model
        self.modules_map: WeakKeyDictionary[nn.Module, torch.dtype] = WeakKeyDictionary()

    def __enter__(self):
        model_layout = self.model.nncf.transformation_layout()
        transformations = model_layout.transformations
        for command in transformations:
            decompressor = command.fn
            if isinstance(decompressor, BaseWeightsDecompressor):
                self.modules_map[decompressor] = decompressor.result_dtype
                decompressor.result_dtype = torch.float32

    def __exit__(self, *args):
        print("exit args=", args)
        for decompressor, dtype in self.modules_map.items():
            decompressor.result_dtype = dtype


def get_wikitext2(nsamples, seqlen, tokenizer):
    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    limit = nsamples * seqlen // 4  # ~1k for 128 samples with seqlen=32 to be aligned with optimum
    text = "".join([" \n" if s == "" else s for s in traindata["text"][:limit]])
    trainenc = tokenizer(text, return_tensors="pt")
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j].to(DEVICE)
        # TODO: or recompute attention_mask/position_ids on tuning?
        attention_mask = torch.ones_like(inp)
        position_ids = torch.cumsum(attention_mask, axis=1) - 1
        trainloader.append({"input_ids": inp, "attention_mask": attention_mask, "position_ids": position_ids})
    return trainloader


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def save_wwb_ref(model, tokenizer):
    if not WWB_REF_FILE.exists():
        wwb_eval = TextEvaluator(base_model=model, tokenizer=tokenizer, use_chat_template=True)
        wwb_eval.dump_gt(str(WWB_REF_FILE))


def get_similarity(model, wwb_eval, ir_dir):
    print("#" * 50 + " Evaluate via WWB " + "#" * 50)
    model = nncf.strip(model)
    with PatchDecompressorDtype(model), warnings.catch_warnings():
        warnings.simplefilter("ignore", category=TracerWarning)
        export_from_model(model.cpu(), ir_dir, patch_16bit_model=True, device="cpu")
    ov_model = OVModelForCausalLM.from_pretrained(
        model_id=ir_dir,
        trust_remote_code=True,
        load_in_8bit=False,
        compile=True,
        ov_config={"KV_CACHE_PRECISION": "f16", "DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"},
    )
    _, all_metrics = wwb_eval.score(ov_model)
    return float(all_metrics["similarity"].iloc[0])


def print_trainable_parameters(module):
    params = list(module.parameters())
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    all_param = sum(p.numel() for p in params)
    print(
        f"trainable params: {trainable_params:,d} || "
        f"all params: {all_param:,d} || "
        f"trainable%: {100 * trainable_params / all_param:.4f}"
    )


@torch.inference_mode()
def calc_hiddens(model, dataloader):
    orig_hiddens = []
    for i in trange(len(dataloader), total=len(dataloader), desc="Calculating original hiddens", leave=False):
        orig_hiddens.append(model.model(**dataloader[i]).last_hidden_state)
    return orig_hiddens

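# Distillation loss: KL divergence between the student and teacher next-token
# distributions over the vocabulary, averaged over all tokens in the batch.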
def kl_div(student_hiddens, teacher_hiddens):
    C = student_hiddens.shape[-1]  # num classes
    return F.kl_div(
        input=F.log_softmax(student_hiddens.view(-1, C), dim=-1),
        target=F.log_softmax(teacher_hiddens.view(-1, C), dim=-1),
        log_target=True,
        reduction="batchmean",
    )


def set_trainable(model, lora_lr, fq_lr):
    model.requires_grad_(False)
    scales_to_train = []
    adapters_to_train = []
    transformations = model.nncf.transformation_layout().transformations
    for command in transformations:
        quantizer = command.fn
        if isinstance(quantizer, (AsymmetricLoraQuantizer, SymmetricLoraQuantizer)) and (quantizer.num_bits == 4):
            quantizer.enable_gradients()
            params = quantizer.get_trainable_params()
            adapters = quantizer.get_adapters()
            adapters_to_train.extend(adapters.values())
            scales_to_train.extend(param for name, param in params.items() if name not in adapters)
    print_trainable_parameters(model)
    return [{"params": adapters_to_train, "lr": lora_lr}, {"params": scales_to_train, "lr": fq_lr}]


def main():
    assert torch.cuda.is_available()
    set_seed(42)

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=TORCH_DTYPE, device_map=DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    save_wwb_ref(model, tokenizer)

    train_loader = get_wikitext2(nsamples=1024, seqlen=1024, tokenizer=tokenizer)
    orig_hiddens = calc_hiddens(model, train_loader)

    example_input = train_loader[0]
    model = compress_weights(
        model,
        mode=CompressWeightsMode.INT4_ASYM,
        group_size=64,
        dataset=Dataset([example_input]),
        compression_format=CompressionFormat.FQ_LORA,
    )

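    # Gradient accumulation: microbatches of 2 samples are accumulated into an
    # effective batch of 32 before each optimizer step.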
    microbatch_size = 2
    batch_size = 32
    grad_accumulation_steps = batch_size // microbatch_size
    num_samples = len(train_loader)
    epoch_samples = num_samples - num_samples % microbatch_size
    microbatches_per_epoch = epoch_samples // microbatch_size

    tb = SummaryWriter(TENSORBOARD_DIR, "QAT with absorbable LoRA")

    wwb_eval = TextEvaluator(
        tokenizer=tokenizer, gt_data=WWB_REF_FILE, test_data=str(WWB_REF_FILE), use_chat_template=True
    )
    best_similarity = get_similarity(model, wwb_eval, LAST_DIR)
    print(f"WWB similarity for initial 4bit model = {best_similarity:.4f}")
    lm_head = deepcopy(model.lm_head)
    lm_head.requires_grad_(False)

    param_to_train = set_trainable(model, lora_lr=5e-4, fq_lr=5e-5)
    opt = torch.optim.AdamW(param_to_train, weight_decay=5e-4)
    model.train()

    aggregated_loss = float("nan")
    loss_numerator = grad_steps = total_microbatches = 0
    for epoch in range(32):
        batch_indices_epoch = torch.randperm(num_samples)[:epoch_samples].chunk(microbatches_per_epoch)
        for batch_indices in tqdm(batch_indices_epoch, desc=f"Train epoch {epoch}", leave=False):
            batch_indices = batch_indices.tolist()
            total_microbatches += 1

            def form_batch(inputs: List[Union[Dict[str, Tensor], Tensor]], indices: List[int]):
                if isinstance(inputs[0], dict):
                    batch = {name: torch.cat([inputs[i][name] for i in indices], dim=0) for name in inputs[0]}
                else:
                    batch = torch.cat([inputs[i] for i in indices], dim=0).to(device=DEVICE, dtype=TORCH_DTYPE)
                return batch

            inputs = form_batch(train_loader, batch_indices)
            with torch.no_grad():
                targets = lm_head(form_batch(orig_hiddens, batch_indices))
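                # The teacher logits are recomputed from cached hidden states, so Gemma-style
                # final-logit soft-capping has to be re-applied here to match the model's own output.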
                if hasattr(model.config, "final_logit_softcapping"):  # Gemma
                    fls = model.config.final_logit_softcapping
                    if fls is not None:
                        targets = targets / fls
                        targets = torch.tanh(targets)
                        targets = targets * fls

            outputs = model(**inputs).logits
            loss = kl_div(outputs, targets.to(dtype=TORCH_DTYPE))

            loss_numerator += loss.item()
            grad_steps += 1

            if not torch.isfinite(loss).item():
                err = f"Fine-tuning loss is {loss}"
                raise ValueError(err)

            (loss / grad_accumulation_steps).backward()

            if grad_steps == grad_accumulation_steps:
                opt.step()
                opt.zero_grad()
                aggregated_loss = loss_numerator / grad_steps
                loss_numerator = grad_steps = 0

            tb.add_scalar("loss", aggregated_loss, total_microbatches)

        smlr = get_similarity(model, wwb_eval, LAST_DIR)
        print(f"WWB similarity = {smlr:.4f}")
        tb.add_scalar("similarity", smlr, total_microbatches)
        if smlr > best_similarity:
            print(f"New best WWB similarity = {smlr:.4f}")
            best_similarity = smlr
            shutil.copytree(LAST_DIR, BEST_DIR, dirs_exist_ok=True)

    print(f"Finetuned OV model has similarity={best_similarity} and is located here: {BEST_DIR}")


if __name__ == "__main__":
    main()
requirements.txt (+6)

@@ -0,0 +1,6 @@

tqdm
whowhatbench @ git+https://github.com/openvinotoolkit/openvino.genai#subdirectory=tools/who_what_benchmark
numpy>=1.23.5,<2
openvino==2025.0
optimum-intel>=1.22.0
transformers>=4.48.0
