# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
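"""
Example of quantization-aware tuning with absorbable LoRA adapters (NNCF FQ_LORA compression format).

The script compresses an LLM to INT4, fine-tunes the LoRA adapters and quantizer scales on
wikitext-2 with a KL-divergence distillation loss against the original bfloat16 model, exports
checkpoints to OpenVINO IR, and keeps the checkpoint with the best WhoWhatBench similarity.
"""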
import random
import shutil
import warnings
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Union
from weakref import WeakKeyDictionary

import torch
import torch.nn.functional as F
from datasets import load_dataset
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.openvino import OVModelForCausalLM
from torch import Tensor
from torch import nn
from torch.jit import TracerWarning
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from tqdm import trange
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from whowhatbench import TextEvaluator

import nncf
from nncf.data.dataset import Dataset
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.quantize_model import compress_weights
from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
from nncf.torch.quantization.layers import BaseWeightsDecompressor
from nncf.torch.quantization.layers import SymmetricLoraQuantizer

MODEL_ID = "HuggingFaceTB/SmolLM-1.7B-Instruct"
DEVICE = "cuda"
TORCH_DTYPE = torch.bfloat16


ROOT = Path(__file__).parent.resolve()
OUTPUT_DIR = ROOT / "output"
TENSORBOARD_DIR = OUTPUT_DIR / "tb"
LAST_DIR = OUTPUT_DIR / "last"
BEST_DIR = LAST_DIR / "best"
for path in [OUTPUT_DIR, TENSORBOARD_DIR, LAST_DIR, BEST_DIR]:
    path.mkdir(exist_ok=True, parents=True)
WWB_REF_FILE = OUTPUT_DIR / "wwb_ref.csv"


# TODO: (nlyalyus) move to Optimum-Intel (ticket 164159)
class PatchDecompressorDtype:
    """
    Temporarily patches the weight decompressors so that bfloat16 models can be exported to OpenVINO.
    """

    def __init__(self, model):
        self.model = model
        self.modules_map: WeakKeyDictionary[nn.Module, torch.dtype] = WeakKeyDictionary()

    def __enter__(self):
        model_layout = self.model.nncf.transformation_layout()
        transformations = model_layout.transformations
        for command in transformations:
            decompressor = command.fn
            if isinstance(decompressor, BaseWeightsDecompressor):
                self.modules_map[decompressor] = decompressor.result_dtype
                decompressor.result_dtype = torch.float32

    def __exit__(self, *args):
        for decompressor, dtype in self.modules_map.items():
            decompressor.result_dtype = dtype


def get_wikitext2(nsamples, seqlen, tokenizer):
    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    limit = nsamples * seqlen // 4  # ~1k for 128 samples with seqlen=32 to be aligned with optimum
    text = "".join([" \n" if s == "" else s for s in traindata["text"][:limit]])
    trainenc = tokenizer(text, return_tensors="pt")
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j].to(DEVICE)
        # TODO: or recompute attention_mask/position_ids on tuning?
        attention_mask = torch.ones_like(inp)
        position_ids = torch.cumsum(attention_mask, dim=1) - 1
        trainloader.append({"input_ids": inp, "attention_mask": attention_mask, "position_ids": position_ids})
    return trainloader


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


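# Dump WhoWhatBench (WWB) reference answers from the original model once; later evaluations
# measure how similar the compressed model's answers are to this reference.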
def save_wwb_ref(model, tokenizer):
    if not WWB_REF_FILE.exists():
        wwb_eval = TextEvaluator(base_model=model, tokenizer=tokenizer, use_chat_template=True)
        wwb_eval.dump_gt(str(WWB_REF_FILE))


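# Strip auxiliary NNCF modules, export the model to OpenVINO IR, and score it with WhoWhatBench
# against the saved reference answers.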
def get_similarity(model, wwb_eval, ir_dir):
    print("#" * 50 + " Evaluate via WWB " + "#" * 50)
    model = nncf.strip(model)
    with PatchDecompressorDtype(model), warnings.catch_warnings():
        warnings.simplefilter("ignore", category=TracerWarning)
        export_from_model(model.cpu(), ir_dir, patch_16bit_model=True, device="cpu")
    ov_model = OVModelForCausalLM.from_pretrained(
        model_id=ir_dir,
        trust_remote_code=True,
        load_in_8bit=False,
        compile=True,
        ov_config={"KV_CACHE_PRECISION": "f16", "DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"},
    )
    _, all_metrics = wwb_eval.score(ov_model)
    return float(all_metrics["similarity"].iloc[0])


def print_trainable_parameters(module):
    params = list(module.parameters())
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    all_param = sum(p.numel() for p in params)
    print(
        f"trainable params: {trainable_params:,d} || "
        f"all params: {all_param:,d} || "
        f"trainable%: {100 * trainable_params / all_param:.4f}"
    )


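# Cache the last hidden states of the original (uncompressed) model for every training sample;
# they are later passed through a frozen lm_head copy to produce distillation targets.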
@torch.inference_mode()
def calc_hiddens(model, dataloader):
    orig_hiddens = []
    for i in trange(len(dataloader), total=len(dataloader), desc="Calculating original hiddens", leave=False):
        orig_hiddens.append(model.model(**dataloader[i]).last_hidden_state)
    return orig_hiddens


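# Distillation loss: KL divergence between the token distributions of the student (tuned INT4
# model) and the teacher (original model), computed from their logits.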
def kl_div(student_hiddens, teacher_hiddens):
    C = student_hiddens.shape[-1]  # num classes
    return F.kl_div(
        input=F.log_softmax(student_hiddens.view(-1, C), dim=-1),
        target=F.log_softmax(teacher_hiddens.view(-1, C), dim=-1),
        log_target=True,
        reduction="batchmean",
    )


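# Freeze the whole model and enable gradients only for the LoRA adapters and the scales of the
# 4-bit LoRA-aware quantizers; adapters and scales get separate learning rates.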
def set_trainable(model, lora_lr, fq_lr):
    model.requires_grad_(False)
    scales_to_train = []
    adapters_to_train = []
    transformations = model.nncf.transformation_layout().transformations
    for command in transformations:
        quantizer = command.fn
        if isinstance(quantizer, (AsymmetricLoraQuantizer, SymmetricLoraQuantizer)) and (quantizer.num_bits == 4):
            quantizer.enable_gradients()
            params = quantizer.get_trainable_params()
            adapters = quantizer.get_adapters()
            adapters_to_train.extend(adapters.values())
            scales_to_train.extend(param for name, param in params.items() if name not in adapters)
    print_trainable_parameters(model)
    return [{"params": adapters_to_train, "lr": lora_lr}, {"params": scales_to_train, "lr": fq_lr}]


def main():
    assert torch.cuda.is_available()
    set_seed(42)

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=TORCH_DTYPE, device_map=DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    save_wwb_ref(model, tokenizer)

    train_loader = get_wikitext2(nsamples=1024, seqlen=1024, tokenizer=tokenizer)
    orig_hiddens = calc_hiddens(model, train_loader)

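    # Compress weights to 4-bit asymmetric with a group size of 64, using one training sample as
    # the example input. The FQ_LORA compression format keeps trainable fake-quantizers with LoRA
    # adapters in the graph so they can be fine-tuned afterwards.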
    example_input = train_loader[0]
    model = compress_weights(
        model,
        mode=CompressWeightsMode.INT4_ASYM,
        group_size=64,
        dataset=Dataset([example_input]),
        compression_format=CompressionFormat.FQ_LORA,
    )

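    # Gradient accumulation: an effective batch of 32 samples is built from microbatches of 2,
    # i.e. the optimizer steps once every 16 microbatches.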
    microbatch_size = 2
    batch_size = 32
    grad_accumulation_steps = batch_size // microbatch_size
    num_samples = len(train_loader)
    epoch_samples = num_samples - num_samples % microbatch_size
    microbatches_per_epoch = epoch_samples // microbatch_size

    tb = SummaryWriter(TENSORBOARD_DIR, "QAT with absorbable LoRA")

    wwb_eval = TextEvaluator(
        tokenizer=tokenizer, gt_data=WWB_REF_FILE, test_data=str(WWB_REF_FILE), use_chat_template=True
    )
    best_similarity = get_similarity(model, wwb_eval, LAST_DIR)
    print(f"WWB similarity for the initial 4-bit model = {best_similarity:.4f}")
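    # A frozen copy of lm_head converts the cached hidden states of the original model into
    # target logits for distillation during training.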
    lm_head = deepcopy(model.lm_head)
    lm_head.requires_grad_(False)

    param_to_train = set_trainable(model, lora_lr=5e-4, fq_lr=5e-5)
    opt = torch.optim.AdamW(param_to_train, weight_decay=5e-4)
    model.train()

    aggregated_loss = float("nan")
    loss_numerator = grad_steps = total_microbatches = 0
    for epoch in range(32):
        batch_indices_epoch = torch.randperm(num_samples)[:epoch_samples].chunk(microbatches_per_epoch)
        for batch_indices in tqdm(batch_indices_epoch, desc=f"Train epoch {epoch}", leave=False):
            batch_indices = batch_indices.tolist()
            total_microbatches += 1

            def form_batch(inputs: List[Union[Dict[str, Tensor], Tensor]], indices: List[int]):
                if isinstance(inputs[0], dict):
                    batch = {name: torch.cat([inputs[i][name] for i in indices], dim=0) for name in inputs[0]}
                else:
                    batch = torch.cat([inputs[i] for i in indices], dim=0).to(device=DEVICE, dtype=TORCH_DTYPE)
                return batch

            inputs = form_batch(train_loader, batch_indices)
            with torch.no_grad():
                targets = lm_head(form_batch(orig_hiddens, batch_indices))
                if hasattr(model.config, "final_logit_softcapping"):  # Gemma
                    fls = model.config.final_logit_softcapping
                    if fls is not None:
                        targets = targets / fls
                        targets = torch.tanh(targets)
                        targets = targets * fls

            outputs = model(**inputs).logits
            loss = kl_div(outputs, targets.to(dtype=TORCH_DTYPE))

            loss_numerator += loss.item()
            grad_steps += 1

            if not torch.isfinite(loss).item():
                err = f"Fine-tuning loss is {loss}"
                raise ValueError(err)

            (loss / grad_accumulation_steps).backward()

            if grad_steps == grad_accumulation_steps:
                opt.step()
                opt.zero_grad()
                aggregated_loss = loss_numerator / grad_steps
                loss_numerator = grad_steps = 0

            tb.add_scalar("loss", aggregated_loss, total_microbatches)

        smlr = get_similarity(model, wwb_eval, LAST_DIR)
        print(f"WWB similarity = {smlr:.4f}")
        tb.add_scalar("similarity", smlr, total_microbatches)
        if smlr > best_similarity:
            print(f"New best WWB similarity = {smlr:.4f}")
            best_similarity = smlr
            shutil.copytree(LAST_DIR, BEST_DIR, dirs_exist_ok=True)

    print(f"The fine-tuned OV model has similarity={best_similarity:.4f} and is located here: {BEST_DIR}")


if __name__ == "__main__":
    main()