 # See the License for the specific language governing permissions and
 # limitations under the License.

+import time
+
 import torch
 from auto_round import AutoRound  # pylint: disable=E0401
 from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
+from auto_round.utils import get_block_names  # pylint: disable=E0401

+from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger


+class AutoRoundQuantizer(Quantizer):
+    def __init__(
+        self,
+        weight_config: dict = {},
+        enable_full_range: bool = False,
+        batch_size: int = 8,
+        amp: bool = True,
+        device=None,
+        lr_scheduler=None,
+        use_quant_input: bool = True,
+        enable_minmax_tuning: bool = True,
+        lr: float = None,
+        minmax_lr: float = None,
+        low_gpu_mem_usage: bool = True,
+        iters: int = 200,
+        seqlen: int = 2048,
+        n_samples: int = 512,
+        sampler: str = "rand",
+        seed: int = 42,
+        n_blocks: int = 1,
+        gradient_accumulate_steps: int = 1,
+        not_use_best_mse: bool = False,
+        dynamic_max_gap: int = -1,
+        scale_dtype="fp32",
+    ):
+        """Init an AutoRoundQuantizer object.
+
+        Args:
+            weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
+                weight_config={
+                    'layer1':  ## layer_name
+                    {
+                        'data_type': 'int',
+                        'bits': 4,
+                        'group_size': 32,
+                        'sym': False,
+                    }
+                    ...
+                }
+                keys:
+                    data_type (str): The data type to be used (default is "int").
+                    bits (int): Number of bits for quantization (default is 4).
+                    group_size (int): Size of the quantization group (default is 128).
+                    sym (bool): Whether to use symmetric quantization (default is None).
+            enable_full_range (bool): Whether to enable full-range quantization (default is False).
+            batch_size (int): Batch size for training (default is 8).
+            amp (bool): Whether to use automatic mixed precision (default is True). Automatically detected and set.
+            device: The device to be used for tuning (default is None). Automatically detected and set.
+            lr_scheduler: The learning rate scheduler to be used.
+            use_quant_input (bool): Whether to use quantized input data (default is True).
+            enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
+            lr (float): The learning rate (default is 0.005).
+            minmax_lr (float): The learning rate for min-max tuning (default is None).
+            low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+            iters (int): Number of iterations (default is 200).
+            seqlen (int): Length of the calibration sequence (default is 2048).
+            n_samples (int): Number of samples (default is 512).
+            sampler (str): The sampling method (default is "rand").
+            seed (int): The random seed (default is 42).
+            n_blocks (int): Number of blocks (default is 1).
+            gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+            not_use_best_mse (bool): Whether to skip using the parameters with the best mean squared error
+                during tuning (default is False).
+            dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+            scale_dtype (str): The data type of the quantization scale to be used (default is "fp32");
+                different kernels support different choices.
+        """
+        self.tokenizer = None
+        self.weight_config = weight_config
+        self.enable_full_range = enable_full_range
+        self.batch_size = batch_size
+        self.amp = amp
+        self.device = device
+        self.lr_scheduler = lr_scheduler
+        self.use_quant_input = use_quant_input
+        self.enable_minmax_tuning = enable_minmax_tuning
+        self.lr = lr
+        self.minmax_lr = minmax_lr
+        self.low_gpu_mem_usage = low_gpu_mem_usage
+        self.iters = iters
+        self.seqlen = seqlen
+        self.n_samples = n_samples
+        self.sampler = sampler
+        self.seed = seed
+        self.n_blocks = n_blocks
+        self.gradient_accumulate_steps = gradient_accumulate_steps
+        self.not_use_best_mse = not_use_best_mse
+        self.dynamic_max_gap = dynamic_max_gap
+        self.data_type = "int"
+        self.scale_dtype = scale_dtype
+
+    def prepare(self, model: torch.nn.Module, *args, **kwargs):
+        """Prepares a given model for quantization.
+
+        Args:
+            model (torch.nn.Module): The model to be prepared.
+
+        Returns:
+            The prepared model.
+        """
+        self.rounder = AutoRoundProcessor(
+            model=model,
+            tokenizer=None,
+            weight_config=self.weight_config,
+            enable_full_range=self.enable_full_range,
+            batch_size=self.batch_size,
+            amp=self.amp,
+            device=self.device,
+            lr_scheduler=self.lr_scheduler,
+            use_quant_input=self.use_quant_input,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            lr=self.lr,
+            minmax_lr=self.minmax_lr,
+            low_gpu_mem_usage=self.low_gpu_mem_usage,
+            iters=self.iters,
+            seqlen=self.seqlen,
+            n_samples=self.n_samples,
+            sampler=self.sampler,
+            seed=self.seed,
+            n_blocks=self.n_blocks,
+            gradient_accumulate_steps=self.gradient_accumulate_steps,
+            not_use_best_mse=self.not_use_best_mse,
+            dynamic_max_gap=self.dynamic_max_gap,
+            data_type=self.data_type,
+            scale_dtype=self.scale_dtype,
+        )
+        self.rounder.prepare()
+        return model
+
+    def convert(self, model: torch.nn.Module, *args, **kwargs):
+        """Converts the prepared model into a quantized model."""
+        model, weight_config = self.rounder.convert()
+        model.autoround_config = weight_config
+        return model
+
+
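For orientation, here is a minimal usage sketch of the new AutoRoundQuantizer (not part of the diff). The model name, the single-layer weight_config entry, and the calibration call are illustrative assumptions; get_autoround_default_run_fn is the helper defined later in this file, and only its model/tokenizer arguments are relied on here.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "facebook/opt-125m"  # placeholder; any decoder-only HF model with detectable blocks
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Per-layer quantization settings keyed by module name (see the docstring above).
    weight_config = {
        "model.decoder.layers.0.self_attn.q_proj": {
            "data_type": "int",
            "bits": 4,
            "group_size": 32,
            "sym": False,
        },
    }

    quantizer = AutoRoundQuantizer(weight_config=weight_config, iters=200, seqlen=2048)
    model = quantizer.prepare(model)                # hooks the first block so its inputs get cached
    get_autoround_default_run_fn(model, tokenizer)  # calibration pass; extra arguments are assumed to default
    q_model = quantizer.convert(model)              # runs AutoRound tuning and attaches q_model.autoround_config
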
 @torch.no_grad()
 def get_autoround_default_run_fn(
     model,
@@ -94,140 +232,95 @@ def get_autoround_default_run_fn(
     )


-class InputCaptureModule(torch.nn.Module):
+class AutoRoundProcessor(AutoRound):

-    def __init__(self) -> None:
-        super().__init__()
-        self.data_pairs = []
-        self.device = "cpu"
+    def prepare(self):
+        """Prepares a given model for quantization."""
+        # logger.info("cache block input")
+        self.start_time = time.time()
+        self.block_names = get_block_names(self.model)
+        if len(self.block_names) == 0:
+            logger.warning("could not find blocks, exit with original model")
+            return
+        if self.amp:
+            self.model = self.model.to(self.amp_dtype)
+        if not self.low_gpu_mem_usage:
+            self.model = self.model.to(self.device)
+        # inputs = self.cache_block_input(block_names[0], self.n_samples)

-    def forward(self, *args, **kwargs):
-        if kwargs and len(args) == 0:
-            # Handle cases where input data is a dict
-            self.data_pairs.append(kwargs)
-        elif args and len(args) == 1:
-            # Handle cases where input data is a Tensor
-            self.data_pairs.append(args[0])
-        else:
-            logger.error("Handle cases where input data is neither a Tensor nor a dict")
+        # cache block input
+        self.inputs = {}
+        self.tmp_block_name = self.block_names[0]
+        self._replace_forward()

+    def convert(self):
+        """Converts a prepared model to a quantized model."""
+        self._recover_forward()
+        inputs = self.inputs[self.tmp_block_name]
+        del self.tmp_block_name

-def recover_dataloader_from_calib_fn(run_fn, run_args):
-    input_capture_model = InputCaptureModule()
-    input_capture_model.eval()
-    run_fn(input_capture_model, *run_args)
-    dataloader = torch.utils.data.DataLoader(input_capture_model.data_pairs)
-    return dataloader
+        del self.inputs
+        if "input_ids" in inputs.keys():
+            dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
+            total_samples = inputs["input_ids"].shape[dim]
+            self.n_samples = total_samples
+            if total_samples < self.train_bs:
+                self.train_bs = total_samples
+                logger.warning(f"force the train batch size to {total_samples} ")
+        self.model = self.model.to("cpu")
+        torch.cuda.empty_cache()
+        self.qdq_weight_round(
+            self.model,
+            inputs,
+            self.block_names,
+            n_blocks=self.n_blocks,
+            device=self.device,
+        )
+        for n, m in self.model.named_modules():
+            if n in self.weight_config.keys():
+                if hasattr(m, "scale"):
+                    self.weight_config[n]["scale"] = m.scale
+                    self.weight_config[n]["zp"] = m.zp
+                    if self.group_size <= 0:
+                        self.weight_config[n]["g_idx"] = torch.tensor(
+                            [0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
+                        )
+                    else:
+                        self.weight_config[n]["g_idx"] = torch.tensor(
+                            [i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
+                        )
+                    delattr(m, "scale")
+                    delattr(m, "zp")
+                else:
+                    self.weight_config[n]["data_type"] = "float"
+                    if self.amp_dtype == torch.bfloat16:
+                        self.weight_config[n]["data_type"] = "bfloat"
+                    self.weight_config[n]["bits"] = 16
+                    self.weight_config[n]["group_size"] = None
+                    self.weight_config[n]["sym"] = None

+        end_time = time.time()
+        cost_time = end_time - self.start_time
+        logger.info(f"quantization tuning time {cost_time}")
+        ## dump a summary
+        quantized_layers = []
+        unquantized_layers = []
+        for n, m in self.model.named_modules():
+            if isinstance(m, tuple(self.supported_types)):
+                if self.weight_config[n]["bits"] == 16:
+                    unquantized_layers.append(n)
+                else:
+                    quantized_layers.append(n)
+        summary_info = (
+            f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
+        )
+        if len(unquantized_layers) > 0:
+            summary_info += f", {unquantized_layers} have not been quantized"

-def autoround_quantize(
-    model,
-    weight_config: dict = {},
-    enable_full_range: bool = False,  ##for symmetric, TODO support later
-    batch_size: int = 8,
-    amp: bool = True,
-    device=None,
-    lr_scheduler=None,
-    use_quant_input: bool = True,
-    enable_minmax_tuning: bool = True,
-    lr: float = None,
-    minmax_lr: float = None,
-    low_gpu_mem_usage: bool = True,
-    iters: int = 200,
-    seqlen: int = 2048,
-    n_samples: int = 512,
-    sampler: str = "rand",
-    seed: int = 42,
-    n_blocks: int = 1,
-    gradient_accumulate_steps: int = 1,
-    not_use_best_mse: bool = False,
-    dynamic_max_gap: int = -1,
-    scale_dtype="fp16",
-    run_fn=None,
-    run_args=None,
-):
-    """The entry point of the autoround weight-only quantization.
-    Args:
-        model: The PyTorch model to be quantized.
-        weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
-            weight_config={
-                'layer1':##layer_name
-                {
-                    'data_type': 'int',
-                    'bits': 4,
-                    'group_size': 32,
-                    'sym': False,
-                }
-                ...
-            }
-            keys:
-                data_type (str): The data type to be used (default is "int").
-                bits (int): Number of bits for quantization (default is 4).
-                group_size (int): Size of the quantization group (default is 128).
-                sym (bool): Whether to use symmetric quantization. (default is None).
-        enable_full_range (bool): Whether to enable full range quantization (default is False).
-        batch_size (int): Batch size for training (default is 8).
-        amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
-        device: The device to be used for tuning (default is None). Automatically detect and set.
-        lr_scheduler: The learning rate scheduler to be used.
-        use_quant_input (bool): Whether to use quantized input data (default is True).
-        enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
-        lr (float): The learning rate (default is 0.005).
-        minmax_lr (float): The learning rate for min-max tuning (default is None).
-        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
-        iters (int): Number of iterations (default is 200).
-        seqlen (int): Length of the sequence.
-        n_samples (int): Number of samples (default is 512).
-        sampler (str): The sampling method (default is "rand").
-        seed (int): The random seed (default is 42).
-        n_blocks (int): Number of blocks (default is 1).
-        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
-        not_use_best_mse (bool): Whether to use mean squared error (default is False).
-        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
-        scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
-            have different choices.
-        run_fn: a calibration function for calibrating the model. Defaults to None.
-        run_args: positional arguments for `run_fn`. Defaults to None.
-
-    Returns:
-        The quantized model.
-    """
-    if run_fn is None or run_fn == get_autoround_default_run_fn:
-        assert run_args is not None, "Please provide tokenizer for AutoRound default calibration."
-        run_fn = get_autoround_default_run_fn
-    dataloader = recover_dataloader_from_calib_fn(run_fn, run_args)
-
-    rounder = AutoRound(
-        model=model,
-        tokenizer=None,
-        bits=4,
-        group_size=128,
-        sym=False,
-        weight_config=weight_config,
-        enable_full_range=enable_full_range,  ##for symmetric, TODO support later
-        batch_size=batch_size,
-        amp=amp,
-        device=device,
-        lr_scheduler=lr_scheduler,
-        dataloader=dataloader,
-        use_quant_input=use_quant_input,
-        enable_minmax_tuning=enable_minmax_tuning,
-        lr=lr,
-        minmax_lr=minmax_lr,
-        low_gpu_mem_usage=low_gpu_mem_usage,
-        iters=iters,
-        seqlen=seqlen,
-        n_samples=n_samples,
-        sampler=sampler,
-        seed=seed,
-        n_blocks=n_blocks,
-        gradient_accumulate_steps=gradient_accumulate_steps,
-        not_use_best_mse=not_use_best_mse,
-        dynamic_max_gap=dynamic_max_gap,
-        data_type="int",
-        scale_dtype=scale_dtype,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    qdq_model, weight_config = rounder.quantize()
-    return qdq_model, weight_config
+        logger.info(summary_info)
+        if len(unquantized_layers) > 0:
+            logger.info(f"Summary: {unquantized_layers} have not been quantized")
+
+        self.quantized = True
+        self.model = self.model.to(self.model_orig_dtype)
+        return self.model, self.weight_config
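As a side note on the metadata recorded by AutoRoundProcessor.convert() above: g_idx simply maps each input channel of a layer's weight to its quantization group. A small, self-contained illustration (the sizes are arbitrary placeholders, not values used by the algorithm):

    import torch

    in_features, group_size = 8, 4
    # Each input channel i belongs to group i // group_size.
    g_idx = torch.tensor([i // group_size for i in range(in_features)], dtype=torch.int32)
    print(g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)

    # group_size <= 0 is treated as one group spanning the whole input dimension (per-channel quantization).
    g_idx_per_channel = torch.tensor([0 for _ in range(in_features)], dtype=torch.int32)
    print(g_idx_per_channel)  # tensor([0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)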