@@ -200,9 +200,9 @@ def __init__(
200
200
self ,
201
201
bits : int = 8 ,
202
202
sym : bool = False ,
203
- tokenizer : Any = None ,
203
+ tokenizer : Optional [ Any ] = None ,
204
204
dataset : Optional [str ] = None ,
205
- ratio : Optional [ float ] = None ,
205
+ ratio : float = 1.0 ,
206
206
group_size : Optional [int ] = None ,
207
207
all_layers : Optional [bool ] = None ,
208
208
sensitivity_metric : Optional [str ] = None ,
@@ -213,7 +213,7 @@ def __init__(
213
213
self .sym = sym
214
214
self .tokenizer = tokenizer
215
215
self .dataset = dataset
216
- self .group_size = group_size
216
+ self .group_size = group_size or ( - 1 if bits == 8 else 128 )
217
217
self .ratio = ratio
218
218
self .all_layers = all_layers
219
219
self .sensitivity_metric = sensitivity_metric
@@ -226,9 +226,9 @@ def post_init(self):
226
226
Safety checker that arguments are correct
227
227
"""
228
228
if self .ratio is not None and not (0 <= self .ratio <= 1 ):
229
- raise ValueError ("damp_percent must between 0 and 1." )
229
+ raise ValueError ("` damp_percent` must between 0 and 1." )
230
230
if self .group_size is not None and self .group_size != - 1 and self .group_size <= 0 :
231
- raise ValueError ("group_size must be greater than 0 or equal to -1" )
231
+ raise ValueError ("` group_size` must be greater than 0 or equal to -1" )
232
232
if self .dataset is not None and isinstance (self .dataset , str ):
233
233
if self .dataset not in ["wikitext2" , "c4" , "c4-new" , "ptb" , "ptb-new" ]:
234
234
raise ValueError (
@@ -239,6 +239,16 @@ def post_init(self):
239
239
if self .bits not in [4 , 8 ]:
240
240
raise ValueError (f"Only support quantization to [4,8] bits but found { self .bits } " )
241
241
242
+ if self .bits == 8 :
243
+ if self .ratio != 1 :
244
+ raise ValueError (
245
+ f"For 8-bit quantization, `ratio` is expected to be set to 1.0, but was set to { self .ratio } "
246
+ )
247
+ if self .group_size != - 1 :
248
+ raise ValueError (
249
+ f"For 8-bit quantization, `group_size` is expected to be set to -1, but was set to { self .group_size } "
250
+ )
251
+
242
252
243
253
def _check_default_4bit_configs(config: PretrainedConfig):
    """Look up the default 4-bit entry registered for a model config.

    Args:
        config: Model configuration; only its `name_or_path` attribute is read.

    Returns:
        The value stored in `_DEFAULT_4BIT_CONFIGS` under `config.name_or_path`,
        or `None` when no entry exists for that key.
    """
    model_key = config.name_or_path
    return _DEFAULT_4BIT_CONFIGS.get(model_key, None)
0 commit comments