Commit eebb5a6

Merge branch 'master' into yuwenzho/refactor_rtn_hqq_awq
2 parents 1aacaa3 + 76b4069

File tree

5 files changed, +257 -91 lines changed


neural_compressor/torch/algorithms/weight_only/teq.py

+64-46
@@ -16,19 +16,26 @@
 # limitations under the License.
 #
 
+import copy
+from typing import Any
+
 import torch
 import transformers
 
+from neural_compressor.torch.algorithms.base_algorithm import Quantizer
 from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
 
-__all__ = ["teq_quantize", "TEQuantizer"]
+__all__ = ["TrainableEquivalentTransformation", "TEQuantizer"]
+
 
+class TrainableEquivalentTransformation:
+    """Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""
 
-class TEQuantizer:
-    """Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
+    _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
+    _PREPARE_ATTRS_PREFIX = "_prepare_"
 
     def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
         """
@@ -41,16 +48,20 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
         self.folding = folding
         self.example_inputs = example_inputs
         self.device = self._get_device()
-        self.dtype = self._get_dtype()
-        self.model.eval()
         self.trained_alphas = {}
         self.absorb_to_layer = absorb_to_layer
+        self._post_initialized = False
+
+    def _post_init(self):
+        self.dtype = self._get_dtype()
+        self.model.to(self.device)
+        self.model.eval()
+        self._post_initialized = True
 
     def _get_device(self):
         """Get the model device
         :return:Model device."""
         device = get_device()
-        self.model.to(device)
         return device
 
     def _get_dtype(self):
@@ -62,6 +73,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
         to the paper for more details
         :param sqrt_w_init: use sqrt weight to init."""
 
+        if not self._post_initialized:
+            self._post_init()
         # freeze model.
         for n, p in self.model.named_parameters():
             p.requires_grad = False
@@ -117,6 +130,9 @@ def add_tuning_scale(self, sqrt_w_init=False):
                     orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                 )
                 set_module(self.model, n, wrapper_module)
+        # Attach the weight config captured at prepare stage to the model
+        self.model._weight_config = self.weight_config
+        self.model._trained_alphas = self.trained_alphas
 
     @torch.no_grad()
     def _absorb_scales(self, layer, scale, layer_name=""):
@@ -204,6 +220,8 @@ def _scale_layer_weight(self, layer, scale): ##input channel
     @torch.no_grad()
     def transform(self):
         """Apply alpha/scale."""
+        if not self._post_initialized:
+            self._post_init()
         for ln_name, layer_names in self.absorb_to_layer.items():
             module = get_module(self.model, ln_name)
             scale = self.trained_alphas[ln_name]
@@ -309,43 +327,43 @@ def save(self, save_scale_file="", save_state_dict_file=""):
         torch.save(self.model.state_dict(), save_state_dict_file)
 
 
-def teq_quantize(
-    model, weight_config={}, absorb_to_layer={}, folding=True, dataloader=None, calib_func=None, example_inputs=None
-):
-    """Run TEQ weight-only quantization."""
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    logger.info("TEQ quantizing start.")
-    if example_inputs is None:
-        if dataloader is None:  # pragma: no cover
-            assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
-        try:
-            for idx, (input, label) in enumerate(dataloader):
-                example_inputs = input
-                break
-        except:  # pragma: no cover
-            for idx, input in enumerate(dataloader):
-                example_inputs = input
-                break
-
-    teq_quantizer = TEQuantizer(model, weight_config, absorb_to_layer, folding, example_inputs)
-
-    # 1. wrapper tuning scale to model
-    teq_quantizer.add_tuning_scale()
-
-    # 2. tuning
-    # custom train function, there calls calib_func
-    if calib_func:  # pragma: no cover
-        calib_func(teq_quantizer.model)
-    else:
-        if dataloader is None:  # pragma: no cover
-            assert False, "Please provide dataloader to train."
-        teq_quantizer.train(dataloader)
-
-    # 3. apply scale to model
-    teq_quantizer.transform()
-
-    # 4. get quantized model
-    teq_quantizer.quantize()
-
-    logger.info("TEQ quantizing done.")
-    return teq_quantizer.model
+class TEQuantizer(Quantizer):
+
+    def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
+        super().__init__(quant_config=quant_config)
+        self.folding = folding
+        self.absorb_to_layer = absorb_to_layer
+        self.example_inputs = example_inputs
+        self._quantizer = TrainableEquivalentTransformation(
+            model=None,
+            weight_config=quant_config,
+            absorb_to_layer=absorb_to_layer,
+            folding=folding,
+            example_inputs=example_inputs,
+        )
+
+    def prepare(self, model, *args, **kwargs):
+        """Prepares a given model for quantization.
+
+        Args:
+            model: A float model to be quantized.
+        Returns:
+            A prepared model.
+        """
+        float_model = model
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        self._quantizer.model = float_model
+        logger.info("TEQ quantizing start.")
+        self._quantizer.add_tuning_scale()
+        for attr in self._quantizer._PREPARE_ATTRS:
+            setattr(float_model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, getattr(self._quantizer, attr))
+        return float_model
+
+    def convert(self, model, *args: Any, **kwargs: Any):
+        for attr in self._quantizer._PREPARE_ATTRS:
+            setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
+        self._quantizer.model = model
+        self._quantizer.transform()
+        self._quantizer.quantize()
+        logger.info("TEQ quantizing done.")
+        return self._quantizer.model
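Note: the refactor above splits the old one-shot teq_quantize() flow into a TrainableEquivalentTransformation core plus a TEQuantizer(Quantizer) wrapper with separate prepare/convert steps, handing the trained state between them via the _prepare_* attributes stashed on the model. A minimal usage sketch of that flow follows; the float_model object and the run_calibration callback are illustrative placeholders, not code from this commit.

import torch
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

# Illustrative inputs; layer names mirror the new unit test, float_model and
# run_calibration are hypothetical placeholders.
weight_config = {"transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"}}
absorb_dict = {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]}
example_inputs = torch.ones([1, 512], dtype=torch.long)

quantizer = TEQuantizer(
    quant_config=weight_config, folding=True, absorb_to_layer=absorb_dict, example_inputs=example_inputs
)
prepared = quantizer.prepare(float_model)   # wraps target layers, stores _prepare_* attrs on the model
run_calibration(prepared)                   # user-supplied run_fn that trains the per-channel alphas
quantized = quantizer.convert(prepared)     # reads the _prepare_* attrs back, folds scales, quantizes weights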

neural_compressor/torch/quantization/algorithm_entry.py

+7-12
@@ -294,16 +294,17 @@ def awq_quantize_entry(
 ###################### TEQ Algo Entry ##################################
 @register_algo(name=TEQ)
 def teq_quantize_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], *args, **kwargs
+    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], mode: Mode, *args, **kwargs
 ) -> torch.nn.Module:
-    from neural_compressor.torch.algorithms.weight_only.teq import teq_quantize
+    from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer
 
     logger.info("Quantize model with the TEQ algorithm.")
     weight_config = {}
     absorb_to_layer = {}
     example_inputs = kwargs.get("example_inputs", None)
     assert example_inputs is not None, "Please provide example_inputs for TEQ quantization."
-    calib_func = kwargs.get("run_fn", None)
+    run_fn = kwargs.get("run_fn", None)
+    inplace = kwargs.get("inplace", True)
     folding = True
     for (op_name, op_type), quant_config in configs_mapping.items():
         if quant_config.dtype == "fp32":
@@ -328,16 +329,10 @@ def teq_quantize_entry(
             absorb_to_layer = quant_config.absorb_to_layer
             folding = quant_config.folding
     assert isinstance(model, torch.nn.Module), "only support torch module"
-
-    model = teq_quantize(
-        model,
-        example_inputs=example_inputs,
-        folding=folding,
-        absorb_to_layer=absorb_to_layer,
-        calib_func=calib_func,
-        weight_config=weight_config,
+    quantizer = TEQuantizer(
+        quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs
     )
-    logger.info("TEQ quantization done.")
+    model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
     return model
 
 
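For orientation: teq_quantize_entry now builds a TEQuantizer and delegates to the execute() method inherited from the base Quantizer class, selecting behavior via the Mode argument. The sketch below shows roughly how such a Mode-based dispatch typically works; it is an assumption for illustration only, the real logic lives in neural_compressor/torch/algorithms/base_algorithm.py and may differ in detail.

# Hedged sketch of a Mode-based dispatch; not the actual base-class code.
def execute(self, model, mode, run_fn=None, example_inputs=None, inplace=True):
    if mode == Mode.PREPARE:   # insert wrappers / observers only
        return self.prepare(model, example_inputs=example_inputs)
    if mode == Mode.CONVERT:   # fold scales and quantize an already-prepared model
        return self.convert(model)
    # Mode.QUANTIZE: one-shot prepare -> run_fn calibration/training -> convert
    prepared = self.prepare(model, example_inputs=example_inputs)
    if run_fn is not None:
        run_fn(prepared)
    return self.convert(prepared)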
neural_compressor/torch/utils/auto_accelerator.py

-4
@@ -137,7 +137,6 @@ def empty_cache(self):
     def synchronize(self):
         pass
 
-    @abstractmethod
     def mark_step(self):
         pass
 
@@ -175,9 +174,6 @@ def empty_cache(self):
     def synchronize(self):
         pass
 
-    def mark_step(self):
-        pass
-
 
 @register_accelerator(name="cuda", priority=PRIORITY_CUDA)
 class CUDA_Accelerator(Auto_Accelerator):
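The auto_accelerator change turns mark_step() from an abstract method into a concrete no-op on the base class, so backends without a lazy-graph flush (CPU, CUDA) simply inherit the default instead of defining their own stub. A minimal sketch of the pattern follows; it is simplified, and the LazyAccelerator backend is hypothetical, not part of this commit.

# Sketch only: the real Auto_Accelerator defines more (abstract) methods than shown here.
class Auto_Accelerator:
    def synchronize(self):
        pass

    def mark_step(self):
        # Default is a no-op; subclasses override only if the backend needs a graph flush.
        pass


class LazyAccelerator(Auto_Accelerator):  # hypothetical backend with lazy execution
    def mark_step(self):
        print("flush the lazily-built graph here")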
@@ -0,0 +1,143 @@
+import copy
+import unittest
+
+import torch
+import transformers
+
+from neural_compressor.common import logger
+from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer
+from neural_compressor.torch.quantization import quantize
+
+
+def generate_random_corpus(nsamples=32):
+    meta_data = []
+    for _ in range(nsamples):
+        inp = torch.ones([1, 512], dtype=torch.long)
+        tar = torch.ones([1, 512], dtype=torch.long)
+        meta_data.append((inp, tar))
+    return meta_data
+
+
+def train(
+    model,
+    train_steps=100,
+    lr=1e-3,
+    warmup_ratio=0.05,
+    gradient_accumulation_steps=1,
+    logging_steps=10,
+    betas=[0.9, 0.9],
+    weight_decay=0,
+    lr_scheduler_type="linear",
+):
+    """Train function."""
+    trained_alphas_list = [torch.ones([128], requires_grad=True)]
+    optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)
+
+    lr_scheduler = transformers.get_scheduler(  # pylint: disable=E1111
+        name=lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
+        num_training_steps=train_steps // gradient_accumulation_steps,
+    )
+
+    logger.info("start training")
+    model.train()
+    global_steps = 0
+    dataloader = generate_random_corpus()
+    while global_steps <= train_steps:
+        for inputs in dataloader:
+            if isinstance(inputs, torch.Tensor):
+                input_id = inputs
+            elif isinstance(inputs, dict):
+                input_id = inputs["input_ids"]
+            else:
+                input_id = inputs[0]
+            output = model(input_id, labels=input_id)
+            loss = output[0] / gradient_accumulation_steps
+            loss.backward()
+            global_steps += 1
+
+            if global_steps % logging_steps == 0:
+                logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))
+
+            if global_steps % gradient_accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step()
+
+            if global_steps >= train_steps:  # pragma: no cover
+                break
+
+    logger.info("finish training")
+    model.eval()
+    return None
+
+
+class TestTEQWeightOnlyQuant(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM",
+            torchscript=True,
+        )
+        self.gptj.seqlen = 512
+
+    def train_func(self):
+        pass
+
+    def test_teq(self):
+        example_inputs = torch.ones([1, 512], dtype=torch.long)
+        test_input = torch.ones([1, 512], dtype=torch.long)
+        model = copy.deepcopy(self.gptj)
+        out0 = model(test_input)
+
+        weight_config = {
+            # 'op_name': (bit, group_size, scheme)
+            "transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
+            "transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
+        }
+        absorb_dict = {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]}
+
+        quantizer = TEQuantizer(
+            quant_config=weight_config, folding=True, absorb_to_layer=absorb_dict, example_inputs=example_inputs
+        )
+        model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train)
+        out1 = model(test_input)
+        self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03))
+
+        quant_config = {
+            "teq": {
+                "global": {
+                    "dtype": "fp32",
+                },
+                "local": {
+                    "transformer.h.0.mlp.fc_in": {
+                        "dtype": "int",
+                        "bits": 8,
+                        "group_size": -1,
+                        "use_sym": True,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                    "transformer.h.0.mlp.fc_out": {
+                        "dtype": "int",
+                        "bits": 4,
+                        "group_size": 32,
+                        "use_sym": False,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                },
+            }
+        }
+        qdq_model = quantize(
+            model=copy.deepcopy(self.gptj), quant_config=quant_config, run_fn=train, example_inputs=example_inputs
+        )
+        self.assertTrue(isinstance(qdq_model, torch.nn.Module))
+        out2 = qdq_model(test_input)
+        self.assertTrue(torch.allclose(out1[0], out2[0]))
+        self.assertTrue(torch.allclose(out2[0], out0[0], atol=0.03))
+
+
+if __name__ == "__main__":
+    unittest.main()
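The new test module is self-contained and ends in unittest.main(), so once it is saved into the repository's test tree it can be run directly; the path below is a placeholder, since the file's actual location is not captured in this view.

python path/to/the_new_teq_test.py
# or, equivalently, via pytest:
pytest path/to/the_new_teq_test.py -v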
