Commit e3c736f

Migrate AutoRound to Torch new 3x API (#1763)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 044e6db commit e3c736f

File tree

3 files changed: +288 −167 lines


neural_compressor/torch/algorithms/weight_only/autoround.py

+224-131
@@ -12,13 +12,151 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
+
 import torch
 from auto_round import AutoRound  # pylint: disable=E0401
 from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
+from auto_round.utils import get_block_names  # pylint: disable=E0401
 
+from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger
 
 
+class AutoRoundQuantizer(Quantizer):
+    def __init__(
+        self,
+        weight_config: dict = {},
+        enable_full_range: bool = False,
+        batch_size: int = 8,
+        amp: bool = True,
+        device=None,
+        lr_scheduler=None,
+        use_quant_input: bool = True,
+        enable_minmax_tuning: bool = True,
+        lr: float = None,
+        minmax_lr: float = None,
+        low_gpu_mem_usage: bool = True,
+        iters: int = 200,
+        seqlen: int = 2048,
+        n_samples: int = 512,
+        sampler: str = "rand",
+        seed: int = 42,
+        n_blocks: int = 1,
+        gradient_accumulate_steps: int = 1,
+        not_use_best_mse: bool = False,
+        dynamic_max_gap: int = -1,
+        scale_dtype="fp32",
+    ):
+        """Init an AutoRoundQuantizer object.
+
+        Args:
+            weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
+                weight_config={
+                    'layer1':  # layer_name
+                        {
+                            'data_type': 'int',
+                            'bits': 4,
+                            'group_size': 32,
+                            'sym': False,
+                        }
+                    ...
+                }
+                keys:
+                    data_type (str): The data type to be used (default is "int").
+                    bits (int): Number of bits for quantization (default is 4).
+                    group_size (int): Size of the quantization group (default is 128).
+                    sym (bool): Whether to use symmetric quantization (default is None).
+            enable_full_range (bool): Whether to enable full-range quantization (default is False).
+            batch_size (int): Batch size for training (default is 8).
+            amp (bool): Whether to use automatic mixed precision (default is True). Automatically detected and set.
+            device: The device to be used for tuning (default is None). Automatically detected and set.
+            lr_scheduler: The learning rate scheduler to be used.
+            use_quant_input (bool): Whether to use quantized input data (default is True).
+            enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
+            lr (float): The learning rate (default is 0.005).
+            minmax_lr (float): The learning rate for min-max tuning (default is None).
+            low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+            iters (int): Number of iterations (default is 200).
+            seqlen (int): Length of the sequence.
+            n_samples (int): Number of samples (default is 512).
+            sampler (str): The sampling method (default is "rand").
+            seed (int): The random seed (default is 42).
+            n_blocks (int): Number of blocks (default is 1).
+            gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+            not_use_best_mse (bool): Whether to use mean squared error (default is False).
+            dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+            scale_dtype (str): The data type of the quantization scale to be used (default is "float32");
+                different kernels have different choices.
+        """
+        self.tokenizer = None
+        self.weight_config = weight_config
+        self.enable_full_range = enable_full_range
+        self.batch_size = batch_size
+        self.amp = amp
+        self.device = device
+        self.lr_scheduler = lr_scheduler
+        self.use_quant_input = use_quant_input
+        self.enable_minmax_tuning = enable_minmax_tuning
+        self.lr = lr
+        self.minmax_lr = minmax_lr
+        self.low_gpu_mem_usage = low_gpu_mem_usage
+        self.iters = iters
+        self.seqlen = seqlen
+        self.n_samples = n_samples
+        self.sampler = sampler
+        self.seed = seed
+        self.n_blocks = n_blocks
+        self.gradient_accumulate_steps = gradient_accumulate_steps
+        self.not_use_best_mse = not_use_best_mse
+        self.dynamic_max_gap = dynamic_max_gap
+        self.data_type = "int"
+        self.scale_dtype = scale_dtype
+
+    def prepare(self, model: torch.nn.Module, *args, **kwargs):
+        """Prepares a given model for quantization.
+        Args:
+            model (torch.nn.Module): The model to be prepared.
+
+        Returns:
+            A prepared model.
+        """
+        self.rounder = AutoRoundProcessor(
+            model=model,
+            tokenizer=None,
+            weight_config=self.weight_config,
+            enable_full_range=self.enable_full_range,
+            batch_size=self.batch_size,
+            amp=self.amp,
+            device=self.device,
+            lr_scheduler=self.lr_scheduler,
+            use_quant_input=self.use_quant_input,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            lr=self.lr,
+            minmax_lr=self.minmax_lr,
+            low_gpu_mem_usage=self.low_gpu_mem_usage,
+            iters=self.iters,
+            seqlen=self.seqlen,
+            n_samples=self.n_samples,
+            sampler=self.sampler,
+            seed=self.seed,
+            n_blocks=self.n_blocks,
+            gradient_accumulate_steps=self.gradient_accumulate_steps,
+            not_use_best_mse=self.not_use_best_mse,
+            dynamic_max_gap=self.dynamic_max_gap,
+            data_type=self.data_type,
+            scale_dtype=self.scale_dtype,
+        )
+        self.rounder.prepare()
+        return model
+
+    def convert(self, model: torch.nn.Module, *args, **kwargs):
+        model, weight_config = self.rounder.convert()
+        model.autoround_config = weight_config
+        return model
+
+
 @torch.no_grad()
 def get_autoround_default_run_fn(
     model,
@@ -94,140 +232,95 @@ def get_autoround_default_run_fn(
     )
 
 
-class InputCaptureModule(torch.nn.Module):
+class AutoRoundProcessor(AutoRound):
 
-    def __init__(self) -> None:
-        super().__init__()
-        self.data_pairs = []
-        self.device = "cpu"
+    def prepare(self):
+        """Prepares a given model for quantization."""
+        # logger.info("cache block input")
+        self.start_time = time.time()
+        self.block_names = get_block_names(self.model)
+        if len(self.block_names) == 0:
+            logger.warning("could not find blocks, exit with original model")
+            return
+        if self.amp:
+            self.model = self.model.to(self.amp_dtype)
+        if not self.low_gpu_mem_usage:
+            self.model = self.model.to(self.device)
+        # inputs = self.cache_block_input(block_names[0], self.n_samples)
 
-    def forward(self, *args, **kwargs):
-        if kwargs and len(args) == 0:
-            # Handle cases where input data is a dict
-            self.data_pairs.append(kwargs)
-        elif args and len(args) == 1:
-            # Handle cases where input data is a Tensor
-            self.data_pairs.append(args[0])
-        else:
-            logger.error("Handle cases where input data is neither a Tensor nor a dict")
+        # cache block input
+        self.inputs = {}
+        self.tmp_block_name = self.block_names[0]
+        self._replace_forward()
 
+    def convert(self):
+        """Converts a prepared model to a quantized model."""
+        self._recover_forward()
+        inputs = self.inputs[self.tmp_block_name]
+        del self.tmp_block_name
 
-def recover_dataloader_from_calib_fn(run_fn, run_args):
-    input_capture_model = InputCaptureModule()
-    input_capture_model.eval()
-    run_fn(input_capture_model, *run_args)
-    dataloader = torch.utils.data.DataLoader(input_capture_model.data_pairs)
-    return dataloader
+        del self.inputs
+        if "input_ids" in inputs.keys():
+            dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
+            total_samples = inputs["input_ids"].shape[dim]
+            self.n_samples = total_samples
+            if total_samples < self.train_bs:
+                self.train_bs = total_samples
+                logger.warning(f"force the train batch size to {total_samples} ")
+        self.model = self.model.to("cpu")
+        torch.cuda.empty_cache()
+        self.qdq_weight_round(
+            self.model,
+            inputs,
+            self.block_names,
+            n_blocks=self.n_blocks,
+            device=self.device,
+        )
+        for n, m in self.model.named_modules():
+            if n in self.weight_config.keys():
+                if hasattr(m, "scale"):
+                    self.weight_config[n]["scale"] = m.scale
+                    self.weight_config[n]["zp"] = m.zp
+                    if self.group_size <= 0:
+                        self.weight_config[n]["g_idx"] = torch.tensor(
+                            [0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
+                        )
+                    else:
+                        self.weight_config[n]["g_idx"] = torch.tensor(
+                            [i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
+                        )
+                    delattr(m, "scale")
+                    delattr(m, "zp")
+                else:
+                    self.weight_config[n]["data_type"] = "float"
+                    if self.amp_dtype == torch.bfloat16:
+                        self.weight_config[n]["data_type"] = "bfloat"
+                    self.weight_config[n]["bits"] = 16
+                    self.weight_config[n]["group_size"] = None
+                    self.weight_config[n]["sym"] = None
 
+        end_time = time.time()
+        cost_time = end_time - self.start_time
+        logger.info(f"quantization tuning time {cost_time}")
+        ## dump a summary
+        quantized_layers = []
+        unquantized_layers = []
+        for n, m in self.model.named_modules():
+            if isinstance(m, tuple(self.supported_types)):
+                if self.weight_config[n]["bits"] == 16:
+                    unquantized_layers.append(n)
+                else:
+                    quantized_layers.append(n)
+        summary_info = (
+            f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
+        )
+        if len(unquantized_layers) > 0:
+            summary_info += f", {unquantized_layers} have not been quantized"
 
-def autoround_quantize(
-    model,
-    weight_config: dict = {},
-    enable_full_range: bool = False,  ##for symmetric, TODO support later
-    batch_size: int = 8,
-    amp: bool = True,
-    device=None,
-    lr_scheduler=None,
-    use_quant_input: bool = True,
-    enable_minmax_tuning: bool = True,
-    lr: float = None,
-    minmax_lr: float = None,
-    low_gpu_mem_usage: bool = True,
-    iters: int = 200,
-    seqlen: int = 2048,
-    n_samples: int = 512,
-    sampler: str = "rand",
-    seed: int = 42,
-    n_blocks: int = 1,
-    gradient_accumulate_steps: int = 1,
-    not_use_best_mse: bool = False,
-    dynamic_max_gap: int = -1,
-    scale_dtype="fp16",
-    run_fn=None,
-    run_args=None,
-):
-    """The entry point of the autoround weight-only quantization.
-    Args:
-        model: The PyTorch model to be quantized.
-        weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
-            weight_config={
-                'layer1':##layer_name
-                    {
-                        'data_type': 'int',
-                        'bits': 4,
-                        'group_size': 32,
-                        'sym': False,
-                    }
-                ...
-            }
-            keys:
-                data_type (str): The data type to be used (default is "int").
-                bits (int): Number of bits for quantization (default is 4).
-                group_size (int): Size of the quantization group (default is 128).
-                sym (bool): Whether to use symmetric quantization. (default is None).
-        enable_full_range (bool): Whether to enable full range quantization (default is False).
-        batch_size (int): Batch size for training (default is 8).
-        amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
-        device: The device to be used for tuning (default is None). Automatically detect and set.
-        lr_scheduler: The learning rate scheduler to be used.
-        use_quant_input (bool): Whether to use quantized input data (default is True).
-        enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
-        lr (float): The learning rate (default is 0.005).
-        minmax_lr (float): The learning rate for min-max tuning (default is None).
-        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
-        iters (int): Number of iterations (default is 200).
-        seqlen (int): Length of the sequence.
-        n_samples (int): Number of samples (default is 512).
-        sampler (str): The sampling method (default is "rand").
-        seed (int): The random seed (default is 42).
-        n_blocks (int): Number of blocks (default is 1).
-        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
-        not_use_best_mse (bool): Whether to use mean squared error (default is False).
-        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
-        scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
-            have different choices.
-        run_fn: a calibration function for calibrating the model. Defaults to None.
-        run_args: positional arguments for `run_fn`. Defaults to None.
-
-    Returns:
-        The quantized model.
-    """
-    if run_fn is None or run_fn == get_autoround_default_run_fn:
-        assert run_args is not None, "Please provide tokenizer for AutoRound default calibration."
-        run_fn = get_autoround_default_run_fn
-    dataloader = recover_dataloader_from_calib_fn(run_fn, run_args)
-
-    rounder = AutoRound(
-        model=model,
-        tokenizer=None,
-        bits=4,
-        group_size=128,
-        sym=False,
-        weight_config=weight_config,
-        enable_full_range=enable_full_range,  ##for symmetric, TODO support later
-        batch_size=batch_size,
-        amp=amp,
-        device=device,
-        lr_scheduler=lr_scheduler,
-        dataloader=dataloader,
-        use_quant_input=use_quant_input,
-        enable_minmax_tuning=enable_minmax_tuning,
-        lr=lr,
-        minmax_lr=minmax_lr,
-        low_gpu_mem_usage=low_gpu_mem_usage,
-        iters=iters,
-        seqlen=seqlen,
-        n_samples=n_samples,
-        sampler=sampler,
-        seed=seed,
-        n_blocks=n_blocks,
-        gradient_accumulate_steps=gradient_accumulate_steps,
-        not_use_best_mse=not_use_best_mse,
-        dynamic_max_gap=dynamic_max_gap,
-        data_type="int",
-        scale_dtype=scale_dtype,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    qdq_model, weight_config = rounder.quantize()
-    return qdq_model, weight_config
+        logger.info(summary_info)
+        if len(unquantized_layers) > 0:
+            logger.info(f"Summary: {unquantized_layers} have not been quantized")
+
+        self.quantized = True
+        self.model = self.model.to(self.model_orig_dtype)
+        return self.model, self.weight_config
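
For context, here is a minimal usage sketch of the two-step flow this commit introduces (prepare, calibrate, convert). It is not part of the commit: the model name, the example layer key in weight_config, and the calibration call are illustrative assumptions, since only part of get_autoround_default_run_fn's signature is visible in this diff.

# Hypothetical usage sketch, not from this commit: drive prepare -> calibrate -> convert.
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import (
    AutoRoundQuantizer,
    get_autoround_default_run_fn,
)

model_name = "facebook/opt-125m"  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Per-layer settings following the schema documented in AutoRoundQuantizer.__init__;
# the layer name is illustrative and must match a module name in the loaded model.
weight_config = {
    "model.decoder.layers.0.self_attn.k_proj": {
        "data_type": "int",
        "bits": 4,
        "group_size": 32,
        "sym": False,
    }
}

quantizer = AutoRoundQuantizer(weight_config=weight_config, iters=200, seqlen=2048)

# prepare() wraps the model in an AutoRoundProcessor and patches the first
# transformer block's forward so that calibration inputs get cached.
model = quantizer.prepare(model)

# Run the default calibration function; it needs the tokenizer (see the removed
# assert in this diff), and its remaining arguments are assumed to have defaults.
get_autoround_default_run_fn(model, tokenizer)

# convert() restores the original forward, runs the block-wise weight rounding,
# and attaches the resulting per-layer config as model.autoround_config.
model = quantizer.convert(model)
print(model.autoround_config)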
