
Commit e7b4b64

PyTorch TEQ Weight-only 3x API Implementation (#1598)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Signed-off-by: Tang, Kaihui <kaihui.tang@intel.com>
1 parent c4010bc commit e7b4b64

File tree

10 files changed, +702 -8 lines changed

neural_compressor/torch/algorithms/weight_only/__init__.py

+1
@@ -15,6 +15,7 @@
 from .rtn import rtn_quantize
 from .gptq import gptq_quantize
 from .awq import awq_quantize
+from .teq import teq_quantize
 from .hqq import hqq_quantize
 from .modules import WeightOnlyLinear
 from .utility import *
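
With this re-export in place, TEQ can be imported the same way as the other weight-only algorithms. A minimal check, using only the import path shown in this diff:

from neural_compressor.torch.algorithms.weight_only import teq_quantize  # new export in this commit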
neural_compressor/torch/algorithms/weight_only/teq.py

+343
@@ -0,0 +1,343 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import transformers

from neural_compressor.torch.utils import logger

from .modules import MulLinear, TEQLinearFakeQuant
from .utility import get_module, quant_tensor, set_module


class TEQuantizer:
    """Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""

    def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
        """
        :param model: the model for quantization
        :param weight_config (dict, optional): contains all info required by RTN. Defaults to {}.
        :param example_inputs: inputs for trace
        """
        self.model = model
        self.weight_config = weight_config
        self.folding = folding
        self.example_inputs = example_inputs
        self.device, self.dtype = self._get_device()
        self.model.eval()
        self.trained_alphas = {}
        self.absorb_to_layer = absorb_to_layer

    def _get_device(self):
        """Get the model device
        :return:Model device."""
        for _, p in self.model.named_parameters():
            return p.data.device, p.data.dtype

    def add_tuning_scale(self, sqrt_w_init=False):
        """The main entry of smooth quant
        to the paper for more details
        :param sqrt_w_init: use sqrt weight to init."""

        # freeze model.
        for n, p in self.model.named_parameters():
            p.requires_grad = False

        for layer_norm in self.absorb_to_layer:
            layer_0_name = self.absorb_to_layer[layer_norm][0]

            module = get_module(self.model, layer_0_name)

            if sqrt_w_init:  # pragma: no cover
                weights = []
                for layer_name in self.absorb_to_layer[layer_norm]:
                    module = get_module(self.model, layer_name)
                    weights.append(module.weight)

                weights = torch.cat(weights, dim=0)
                max_value = torch.sqrt(torch.max(torch.abs(weights), dim=0).values)
                max_value[max_value == 0] = 1.0
                max_value = 1.0 / max_value

                alpha = torch.nn.Parameter(max_value)
                alpha = alpha.to(self.device)
            else:
                alpha = torch.nn.Parameter(torch.ones(module.weight.shape[1], device=self.device))

            self.trained_alphas[layer_norm] = alpha
            for layer_name in self.absorb_to_layer[layer_norm]:
                if self.weight_config.get(layer_name) is None:  # pragma: no cover
                    logger.info(f"layer {layer_name} not in weight config, skip.")
                    continue
                num_bits = self.weight_config[layer_name]["bits"]
                group_size = self.weight_config[layer_name]["group_size"]
                scheme = self.weight_config[layer_name]["scheme"]

                module = get_module(self.model, layer_name)
                wrapper_module = TEQLinearFakeQuant(
                    orig_layer=module, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                )
                set_module(self.model, layer_name, wrapper_module)

        for n, m in self.model.named_modules():
            if isinstance(m, torch.nn.Linear) and "orig_layer" not in n:
                if self.weight_config.get(n) is None:  # pragma: no cover
                    logger.info(f"out of absorbed layer {n} not in weight config, skip.")
                    continue
                num_bits = self.weight_config[layer_name]["bits"]
                group_size = self.weight_config[layer_name]["group_size"]
                scheme = self.weight_config[layer_name]["scheme"]

                alpha = torch.nn.Parameter(torch.ones(m.weight.shape[1], device=self.device))
                alpha.requires_grad_(False)
                wrapper_module = TEQLinearFakeQuant(
                    orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                )
                set_module(self.model, n, wrapper_module)

    @torch.no_grad()
    def _absorb_scales(self, layer, scale, layer_name=""):
        """Absorb the scale to the layer at output channel
        :param layer: The module
        :param scale: The scale to be absorbed
        :param layer_name: The layer name."""
        # for insert mul
        if not self.folding:  # pragma: no cover
            if isinstance(layer, MulLinear):
                set_module(self.model, layer_name, layer.linear)  ##recover
            else:
                new_module = MulLinear(layer, scale)
                set_module(self.model, layer_name, new_module)
            self.weight_config[layer_name + ".linear"] = self.weight_config[layer_name]
            return

        if (
            isinstance(layer, torch.nn.BatchNorm2d)
            or isinstance(layer, torch.nn.GroupNorm)
            or isinstance(layer, torch.nn.InstanceNorm2d)
        ):
            if layer.affine:  # pragma: no cover
                layer.weight *= scale
                layer.bias *= scale
            else:  # pragma: no cover
                layer.affine = True
                weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale
                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
                bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype)
                layer.bias = torch.nn.Parameter(bias, requires_grad=False)
        elif isinstance(layer, torch.nn.LayerNorm):
            if layer.elementwise_affine:
                layer.weight *= scale
                layer.bias *= scale
            else:  # pragma: no cover
                layer.elementwise_affine = True
                weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale
                layer.weight = torch.nn.Parameter(torch.ones(weight, requires_grad=False))
                bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype)
                layer.bias = torch.nn.Parameter(bias, requires_grad=False)

        elif isinstance(layer, torch.nn.Conv2d):  # pragma: no cover
            ## the order could not be changed
            if hasattr(layer, "bias") and (layer.bias is not None):
                layer.bias *= scale
            scale = scale.view(scale.shape[0], 1, 1, 1)
            layer.weight *= scale

        elif isinstance(layer, torch.nn.Linear):  # pragma: no cover
            if hasattr(layer, "bias") and (layer.bias is not None):
                layer.bias *= scale
            scale = scale.view(scale.shape[0], 1)
            layer.weight *= scale

        elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm":  ##quite tricky
            layer.weight *= scale

        else:  # pragma: no cover
            logger.info(
                f"found unsupported layer {type(layer)}, try to multiply scale to "
                f"weight and bias directly, this may introduce accuracy issue, please have a check "
            )
            if hasattr(layer, "weight") and layer.weight is not None:
                layer.weight *= scale
            if hasattr(layer, "bias") and layer.bias is not None:
                layer.bias *= scale

    @torch.no_grad()
    def _scale_layer_weight(self, layer, scale):  ##input channel
        """Scale the layer weights at input channel, depthwise conv output channel
        :param layer_name: The layer name
        :param scale: The scale to be multiplied
        :return:"""
        if layer.__class__.__name__ == "MulLinear":
            layer = layer.linear

        if layer.__class__.__name__ == "TEQLinearFakeQuant":
            layer = layer.orig_layer

        scale = scale.view(1, scale.shape[0])
        layer.weight = torch.nn.Parameter(layer.weight * scale)
        return scale

    @torch.no_grad()
    def transform(self):
        """Apply alpha/scale."""
        for ln_name, layer_names in self.absorb_to_layer.items():
            module = get_module(self.model, ln_name)
            scale = self.trained_alphas[ln_name]
            scale = torch.clip(scale, 1e-5)
            input_scale = 1.0 / scale
            if hasattr(module, "orig_layer"):
                module = module.orig_layer

            self._absorb_scales(module, input_scale, layer_name=ln_name)
            weight_scale = scale
            for layer_name in layer_names:
                layer_module = get_module(self.model, layer_name)
                self._scale_layer_weight(layer_module, weight_scale)

        # for Folding = True
        for n, m in self.model.named_modules():
            if isinstance(m, TEQLinearFakeQuant):
                set_module(self.model, n, m.orig_layer)

    def train(
        self,
        dataloader,
        train_steps=1000,
        lr=1e-3,
        warmup_ratio=0.05,
        gradient_accumulation_steps=1,
        logging_steps=10,
        betas=[0.9, 0.9],
        weight_decay=0,
        lr_scheduler_type="linear",
    ):
        """Train function."""
        trained_alphas_list = []
        for item in self.trained_alphas.items():
            trained_alphas_list.append(item[1])
        optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)

        lr_scheduler = transformers.get_scheduler(  # pylint: disable=E1111
            name=lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
            num_training_steps=train_steps // gradient_accumulation_steps,
        )

        logger.info("start training")
        self.model.train()
        global_steps = 0

        while global_steps <= train_steps:
            for inputs in dataloader:
                if isinstance(inputs, torch.Tensor):
                    input_id = inputs
                elif isinstance(inputs, dict):
                    input_id = inputs["input_ids"]
                else:
                    input_id = inputs[0]

                input_id = input_id.to(self.device)
                output = self.model(input_id, labels=input_id)
                loss = output[0] / gradient_accumulation_steps
                loss.backward()
                global_steps += 1

                if global_steps % logging_steps == 0:
                    logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))

                if global_steps % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    lr_scheduler.step()

                if global_steps >= train_steps:  # pragma: no cover
                    break

        logger.info("finish training")
        self.model.eval()
        return None

    @torch.no_grad()
    def quantize(self):
        """quantization."""

        for n, m in self.model.named_modules():
            if self.weight_config.get(n) is None:  # pragma: no cover
                logger.info(f"quantize layer {n} not in weight config, skip.")
                continue
            num_bits = self.weight_config[n]["bits"]
            group_size = self.weight_config[n]["group_size"]
            scheme = self.weight_config[n]["scheme"]
            if isinstance(m, torch.nn.Linear):  # pragma: no cover
                quant_tensor(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)

    def save(self, save_scale_file="", save_state_dict_file=""):
        """
        save alpha/scale or model weight
        :param save_scale_file: save alpha/scale with torch.save
        :param save_state_dict_file: save model state_dict
        """
        if save_scale_file:  # pragma: no cover
            torch.save(self.trained_alphas, save_scale_file)

        if save_state_dict_file:  # pragma: no cover
            torch.save(self.model.state_dict(), save_state_dict_file)


def teq_quantize(
    model, weight_config={}, absorb_to_layer={}, folding=True, dataloader=None, calib_func=None, example_inputs=None
):
    """Run TEQ weight-only quantization."""
    assert isinstance(model, torch.nn.Module), "only support torch module"
    logger.info("TEQ quantizing start.")
    if example_inputs is None:
        if dataloader is None:  # pragma: no cover
            assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
        try:
            for idx, (input, label) in enumerate(dataloader):
                example_inputs = input
                break
        except:  # pragma: no cover
            for idx, input in enumerate(dataloader):
                example_inputs = input
                break

    teq_quantizer = TEQuantizer(model, weight_config, absorb_to_layer, folding, example_inputs)

    # 1. wrapper tuning scale to model
    teq_quantizer.add_tuning_scale()

    # 2. tuning
    # custom train function, there calls calib_func
    if calib_func:  # pragma: no cover
        calib_func(teq_quantizer.model)
    else:
        if dataloader is None:  # pragma: no cover
            assert False, "Please provide dataloader to train."
        teq_quantizer.train(dataloader)

    # 3. apply scale to model
    teq_quantizer.transform()

    # 4. get quantized model
    teq_quantizer.quantize()

    logger.info("TEQ quantizing done.")
    return teq_quantizer.model
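
teq_quantize drives the four steps shown above: wrap the targeted linears with TEQLinearFakeQuant, train the per-channel alphas on a calibration dataloader, fold the learned scales back into the absorbing norm layers, then fake-quantize the weights. A minimal usage sketch built only from the signature and the dict keys ("bits", "group_size", "scheme") read by this file; the model, layer names, and dataloader below are placeholders, and training runs for the default 1000 steps:

import torch
from transformers import AutoModelForCausalLM

from neural_compressor.torch.algorithms.weight_only import teq_quantize

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

# Per-layer settings consumed by TEQuantizer (keys taken from this file).
weight_config = {
    "model.decoder.layers.0.self_attn.q_proj": {"bits": 4, "group_size": 32, "scheme": "sym"},
    "model.decoder.layers.0.self_attn.k_proj": {"bits": 4, "group_size": 32, "scheme": "sym"},
}

# Norm layer -> linears that share its trainable scale (names are illustrative).
absorb_to_layer = {
    "model.decoder.layers.0.self_attn_layer_norm": [
        "model.decoder.layers.0.self_attn.q_proj",
        "model.decoder.layers.0.self_attn.k_proj",
    ]
}

# Any iterable yielding input_ids tensors satisfies the built-in train loop.
calib_dataloader = [torch.randint(0, model.config.vocab_size, (1, 512)) for _ in range(8)]

q_model = teq_quantize(
    model,
    weight_config=weight_config,
    absorb_to_layer=absorb_to_layer,
    folding=True,
    dataloader=calib_dataloader,
)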

neural_compressor/torch/algorithms/weight_only/utility.py

+1
@@ -149,6 +149,7 @@ def qdq_weight_asym(weight, bits=4, quantile=1.0, return_int=False, **kwargs):
     zp.unsqueeze_(dim=-1)
     weight.div_(scale)
     weight.round_()
+    weight.add_(zp)
     weight.clamp_(0, maxq)
     keep_scale = kwargs.get("double_quant", False)
     if return_int or keep_scale:
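
The one-line fix matters because asymmetric quantization maps a weight to clamp(round(w / scale) + zp, 0, 2**bits - 1); without adding the zero point before clamping, every negative value collapses to 0. A standalone sketch of the round trip (not the library's qdq_weight_asym; values and the min/max scale choice are illustrative):

import torch

bits = 4
maxq = 2**bits - 1  # 15
weight = torch.tensor([[-1.0, -0.25, 0.0, 0.5, 1.0]])

wmin = weight.min(-1, keepdim=True).values
wmax = weight.max(-1, keepdim=True).values
scale = (wmax - wmin) / maxq
zp = torch.round(-wmin / scale)  # zero point shifts the grid so wmin lands near 0

q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq)  # the "+ zp" is what this commit adds
dequant = (q - zp) * scale
print(q, dequant)  # dequant tracks weight; dropping "+ zp" would clamp all negatives to 0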

neural_compressor/torch/quantization/__init__.py

+2
@@ -25,6 +25,8 @@
     get_default_static_config,
     SmoothQuantConfig,
     get_default_sq_config,
+    TEQConfig,
+    get_default_teq_config,
     HQQConfig,
     get_default_hqq_config,
 )
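
With TEQConfig and get_default_teq_config exported from neural_compressor.torch.quantization, TEQ becomes selectable through the 3x config API. A hedged sketch; the TEQConfig fields and the entry-point wiring are not shown in this diff, so the commented call below is an assumption, not the documented API:

from neural_compressor.torch.quantization import TEQConfig, get_default_teq_config

# Fetch the default TEQ weight-only settings exported by this commit.
quant_config = get_default_teq_config()
print(quant_config)

# Assumed usage: hand the config to the 3x quantization entry point along with a
# calibration/training function (names here are assumptions, not from this diff).
# from neural_compressor.torch.quantization import quantize
# q_model = quantize(model, quant_config=quant_config, run_fn=calib_fn)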
