@@ -16,19 +16,26 @@
 # limitations under the License.
 #

+import copy
+from typing import Any
+
 import torch
 import transformers

+from neural_compressor.torch.algorithms.base_algorithm import Quantizer
 from neural_compressor.torch.utils import get_device, logger

 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module

-__all__ = ["teq_quantize", "TEQuantizer"]
+__all__ = ["TrainableEquivalentTransformation", "TEQuantizer"]
+

+class TrainableEquivalentTransformation:
+    """Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""

-class TEQuantizer:
-    """Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
+    _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
+    _PREPARE_ATTRS_PREFIX = "_prepare_"

     def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
         """
@@ -41,16 +48,20 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
         self.folding = folding
         self.example_inputs = example_inputs
         self.device = self._get_device()
-        self.dtype = self._get_dtype()
-        self.model.eval()
         self.trained_alphas = {}
         self.absorb_to_layer = absorb_to_layer
+        self._post_initialized = False
+
+    def _post_init(self):
+        self.dtype = self._get_dtype()
+        self.model.to(self.device)
+        self.model.eval()
+        self._post_initialized = True

     def _get_device(self):
         """Get the model device
         :return:Model device."""
         device = get_device()
-        self.model.to(device)
         return device

     def _get_dtype(self):
@@ -62,6 +73,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
         to the paper for more details
         :param sqrt_w_init: use sqrt weight to init."""

+        if not self._post_initialized:
+            self._post_init()
         # freeze model.
         for n, p in self.model.named_parameters():
             p.requires_grad = False
@@ -117,6 +130,9 @@ def add_tuning_scale(self, sqrt_w_init=False):
                     orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                 )
                 set_module(self.model, n, wrapper_module)
+        # Attach the weight config captured at prepare stage to the model
+        self.model._weight_config = self.weight_config
+        self.model._trained_alphas = self.trained_alphas

     @torch.no_grad()
     def _absorb_scales(self, layer, scale, layer_name=""):
@@ -204,6 +220,8 @@ def _scale_layer_weight(self, layer, scale):  ##input channel
     @torch.no_grad()
     def transform(self):
         """Apply alpha/scale."""
+        if not self._post_initialized:
+            self._post_init()
         for ln_name, layer_names in self.absorb_to_layer.items():
             module = get_module(self.model, ln_name)
             scale = self.trained_alphas[ln_name]
@@ -309,43 +327,43 @@ def save(self, save_scale_file="", save_state_dict_file=""):
             torch.save(self.model.state_dict(), save_state_dict_file)


-def teq_quantize(
-    model, weight_config={}, absorb_to_layer={}, folding=True, dataloader=None, calib_func=None, example_inputs=None
-):
-    """Run TEQ weight-only quantization."""
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    logger.info("TEQ quantizing start.")
-    if example_inputs is None:
-        if dataloader is None:  # pragma: no cover
-            assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
-        try:
-            for idx, (input, label) in enumerate(dataloader):
-                example_inputs = input
-                break
-        except:  # pragma: no cover
-            for idx, input in enumerate(dataloader):
-                example_inputs = input
-                break
-
-    teq_quantizer = TEQuantizer(model, weight_config, absorb_to_layer, folding, example_inputs)
-
-    # 1. wrapper tuning scale to model
-    teq_quantizer.add_tuning_scale()
-
-    # 2. tuning
-    # custom train function, there calls calib_func
-    if calib_func:  # pragma: no cover
-        calib_func(teq_quantizer.model)
-    else:
-        if dataloader is None:  # pragma: no cover
-            assert False, "Please provide dataloader to train."
-        teq_quantizer.train(dataloader)
-
-    # 3. apply scale to model
-    teq_quantizer.transform()
-
-    # 4. get quantized model
-    teq_quantizer.quantize()
-
-    logger.info("TEQ quantizing done.")
-    return teq_quantizer.model
+class TEQuantizer(Quantizer):
+
+    def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
+        super().__init__(quant_config=quant_config)
+        self.folding = folding
+        self.absorb_to_layer = absorb_to_layer
+        self.example_inputs = example_inputs
+        self._quantizer = TrainableEquivalentTransformation(
+            model=None,
+            weight_config=quant_config,
+            absorb_to_layer=absorb_to_layer,
+            folding=folding,
+            example_inputs=example_inputs,
+        )
+
+    def prepare(self, model, *args, **kwargs):
+        """Prepares a given model for quantization.
+
+        Args:
+            model: A float model to be quantized.
+        Returns:
+            A prepared model.
+        """
+        float_model = model
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        self._quantizer.model = float_model
+        logger.info("TEQ quantizing start.")
+        self._quantizer.add_tuning_scale()
+        for attr in self._quantizer._PREPARE_ATTRS:
+            setattr(float_model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, getattr(self._quantizer, attr))
+        return float_model
+
+    def convert(self, model, *args: Any, **kwargs: Any):
+        for attr in self._quantizer._PREPARE_ATTRS:
+            setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
+        self._quantizer.model = model
+        self._quantizer.transform()
+        self._quantizer.quantize()
+        logger.info("TEQ quantizing done.")
+        return self._quantizer.model
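One way to read this commit: the one-shot `teq_quantize` helper is gone, and the `Quantizer` base class's prepare/convert split now carries the flow, with the caller running the scale-training step in between; `prepare()` stashes `weight_config` and `trained_alphas` on the model under the `_prepare_` prefix so a later `convert()` call can recover them. A minimal usage sketch under that reading; the import path is inferred from this file's relative imports, and `get_float_model`, `train_fn`, and the config dicts are hypothetical placeholders, not part of this commit:

```python
import torch

# Assumed module path, based on the relative imports in this file.
from neural_compressor.torch.algorithms.weight_only.teq import TEQuantizer

# Hypothetical inputs: a float model, a per-layer weight config, and the
# mapping from each absorbing module (e.g. a LayerNorm) to the linear
# layers its learned scale folds into.
model = get_float_model()  # hypothetical helper returning an nn.Module
quant_config = {"transformer.h.0.mlp.fc_in": {"bits": 4, "group_size": 128, "scheme": "asym"}}
absorb_to_layer = {"transformer.h.0.ln_2": ["transformer.h.0.mlp.fc_in"]}
example_inputs = torch.ones([1, 512], dtype=torch.long)

quantizer = TEQuantizer(
    quant_config=quant_config,
    folding=True,
    absorb_to_layer=absorb_to_layer,
    example_inputs=example_inputs,
)

# prepare() wraps the target linears with trainable scales and attaches
# _prepare_weight_config / _prepare_trained_alphas to the model itself.
prepared = quantizer.prepare(model)

# The training step that teq_quantize used to run internally is now the
# caller's responsibility: any loop that trains the wrapped scales.
train_fn(prepared)  # hypothetical calibration/training function

# convert() reads the stashed attributes back, folds the learned scales
# into the weights, and quantizes them.
quantized = quantizer.convert(prepared)
```

Stashing the prepare-stage state on the model rather than on the quantizer also means the prepared model can round-trip through a deepcopy or a separate `convert()` call site without losing the trained alphas.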