forked from huggingface/optimum-intel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquantization.py
577 lines (503 loc) · 24 KB
/
quantization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import nncf
import openvino
import torch
import transformers
from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
from nncf.torch import create_compressed_model, register_default_init_args, register_module
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch.initialization import PTInitializingDataLoader
from openvino._offline_transformations import compress_quantize_weights_transformation
from openvino.runtime import Core, Tensor
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator
from transformers.pytorch_utils import Conv1D
from transformers.utils import is_accelerate_available
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
from optimum.exporters.tasks import TasksManager
from optimum.quantization_base import OptimumQuantizer
from ...exporters.openvino import export, export_pytorch_via_onnx
from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
from ..utils.modeling_utils import get_model_device
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel
from .utils import (
MAX_ONNX_OPSET,
MIN_ONNX_QDQ_OPSET,
ONNX_WEIGHTS_NAME,
OV_XML_FILE_NAME,
)
# `Dataset` is only needed for type annotations, so it is imported only for static type checkers
# and only when the `datasets` package is installed.
if is_datasets_available():
    if TYPE_CHECKING:
        from datasets import Dataset

# Register transformers' `Conv1D` layer as a custom module with NNCF, with no algorithms ignored for it.
register_module(ignored_algorithms=[])(Conv1D)

# Module-wide OpenVINO runtime `Core`, used to read back exported ONNX models before re-saving them as IR.
core = Core()

logger = logging.getLogger(__name__)
class OVDataLoader(PTInitializingDataLoader):
    """
    NNCF initializing data loader whose batches are passed to the model as keyword arguments only.
    """

    def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
        # No positional arguments: the whole batch is forwarded as keyword arguments.
        positional_args = ()
        return positional_args, dataloader_output

    @property
    def batch_size(self):
        """
        Batch size of the wrapped data loader.

        When the wrapped loader reports no batch size and it is an accelerate `DataLoaderStateMixin`,
        its `total_batch_size` is returned instead.
        """
        wrapped_loader = self._data_loader
        size = wrapped_loader.batch_size
        if is_accelerate_available():
            from accelerate.data_loader import DataLoaderStateMixin

            use_total_size = size is None and isinstance(wrapped_loader, DataLoaderStateMixin)
            if use_total_size:
                size = wrapped_loader.total_batch_size
        return size
class InferRequestWrapper:
    """
    Wrapper around an OpenVINO infer request (or compiled model) that records the inputs it receives.

    Every input passed through `__call__`, `infer` or `start_async` is appended to `data_cache` before being
    forwarded to the wrapped `request`. Passing an existing list as `data_cache` lets the caller read the
    collected inputs afterwards. Any attribute not defined on the wrapper is looked up on the wrapped `request`.

    Args:
        request: The object to wrap; it must be callable and/or expose `infer` and `results`.
        data_cache (`list`, *optional*): The list the recorded inputs are appended to. A new list is created
            when not provided.
    """

    def __init__(self, request, data_cache=None):
        self.request = request
        if data_cache is None:
            data_cache = []
        self.data_cache = data_cache

    def __call__(self, *args, **kwargs):
        # Record the inputs whether they are passed positionally or as the `inputs=` keyword. The previous
        # `append(*args)` raised a TypeError with zero or with several positional arguments.
        inputs = args[0] if args else kwargs.get("inputs")
        self.data_cache.append(inputs)
        return self.request(*args, **kwargs)

    def infer(self, inputs: Any = None, share_inputs: bool = False):
        self.data_cache.append(inputs)
        return self.request.infer(inputs, share_inputs)

    def start_async(
        self,
        inputs: Any = None,
        userdata: Any = None,
        share_inputs: bool = False,
        *,
        shared_memory: Any = None,
    ):
        # Inference is run synchronously here, so results are immediately available through `get_tensor`.
        # `userdata` and `shared_memory` are accepted for signature compatibility but are not used.
        self.data_cache.append(inputs)
        self.request.infer(inputs, share_inputs, share_outputs=True)

    def wait(self):
        # Nothing to wait for: `start_async` has already completed inference.
        pass

    def get_tensor(self, name: str):
        return Tensor(self.request.results[name])

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Guard `request` itself so that an instance whose
        # `request` is not set (e.g. created via `__new__` during copy/unpickling) raises AttributeError
        # instead of recursing infinitely.
        if attr == "request":
            raise AttributeError(attr)
        return getattr(self.request, attr)
class OVQuantizer(OptimumQuantizer):
    """
    Handle the NNCF quantization process.

    OpenVINO models (`OVBaseModel`) are quantized directly with NNCF. PyTorch models (`torch.nn.Module`) are
    compressed with NNCF and then exported to OpenVINO IR.
    """

    def __init__(self, model: transformers.PreTrainedModel, task: Optional[str] = None, seed: int = 42, **kwargs):
        """
        Args:
            model (`transformers.PreTrainedModel`):
                The [PreTrainedModel](https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel) to quantize.
            task (`str`, defaults to None):
                The task defining the model topology used for the ONNX export.
            seed (`int`, defaults to 42):
                The random seed to use when shuffling the calibration dataset.
        """
        super().__init__()
        self.model = model
        feature = kwargs.pop("feature", None)
        if feature is not None:
            logger.warning("`feature` is deprecated and will be removed in a future version. Use `task` instead.")
            # Only report a conflict when both the deprecated `feature` and `task` were actually provided.
            if task is not None and task != feature:
                logger.warning(
                    f"Both `feature` and `task` were specified. {task} will be used to define the model topology for the model ONNX export."
                )
        self.task = task or feature
        self.seed = seed
        self.input_names = None
        signature = inspect.signature(self.model.forward)
        # Parameter names of `model.forward`; used to drop unused dataset columns.
        self._signature_columns = list(signature.parameters.keys())
        self._export_input_names = [
            column for column in self._signature_columns if column not in {"label", "labels", "label_ids"}
        ]

    @classmethod
    def from_pretrained(cls, model: PreTrainedModel, **kwargs):
        # TODO : Create model
        return cls(model, **kwargs)

    def quantize(
        self,
        calibration_dataset: "Dataset" = None,
        save_directory: Union[str, Path] = None,
        ov_config: OVConfig = None,
        file_name: Optional[str] = None,
        batch_size: int = 1,
        data_collator: Optional[DataCollator] = None,
        remove_unused_columns: bool = True,
        weights_only: bool = False,
        **kwargs,
    ):
        """
        Quantize a model given the optimization specifications defined in `ov_config`.

        Args:
            calibration_dataset (`datasets.Dataset`):
                The dataset to use for the calibration step. Required unless `weights_only=True`.
            save_directory (`Union[str, Path]`):
                The directory where the quantized model should be saved.
            ov_config (`OVConfig`, *optional*):
                The configuration containing the parameters related to quantization. The deprecated
                `quantization_config` keyword argument is still accepted as a fallback.
            file_name (`str`, *optional*):
                The model file name to use when saving the model (PyTorch models only). Overwrites the default
                file name `OV_XML_FILE_NAME`; the `.xml` suffix is enforced.
            batch_size (`int`, defaults to 1):
                The number of calibration samples to load per batch.
            data_collator (`DataCollator`, *optional*):
                The function to use to form a batch from a list of elements of the calibration dataset.
            remove_unused_columns (`bool`, defaults to `True`):
                Whether or not to remove the columns unused by the model forward method.
            weights_only (`bool`, defaults to `False`):
                Compress weights to integer precision (8-bit by default) while keeping activations
                floating-point. Fits best for LLM footprint reduction and performance acceleration.
            kwargs:
                For OpenVINO models, extra keyword arguments forwarded to `nncf.quantize` (e.g. `subset_size`,
                `model_type`, `fast_bias_correction`).

        Raises:
            ValueError: If `save_directory` is missing, or if `calibration_dataset` is missing while
                `weights_only` is `False`.
            TypeError: If the model is neither an `OVBaseModel` nor a `torch.nn.Module`.

        Examples:
        ```python
        >>> from optimum.intel.openvino import OVQuantizer, OVModelForSequenceClassification
        >>> from transformers import AutoModelForSequenceClassification
        >>> model = OVModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True)
        >>> # or
        >>> model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        >>> quantizer = OVQuantizer.from_pretrained(model, task="text-classification")
        >>> quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="./quantized_model")
        >>> optimized_model = OVModelForSequenceClassification.from_pretrained("./quantized_model")
        ```

        ```python
        >>> from optimum.intel.openvino import OVQuantizer, OVModelForCausalLM
        >>> from transformers import AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b")
        >>> quantizer = OVQuantizer.from_pretrained(model, task="text-generation")
        >>> quantizer.quantize(save_directory="./quantized_model", weights_only=True)
        >>> optimized_model = OVModelForCausalLM.from_pretrained("./quantized_model")
        ```
        """
        if save_directory is None:
            # TODO : can be set to self.model.config.name_or_path for OVModels when not provided
            raise ValueError("`save_directory` needs to be specified")

        if weights_only:
            if calibration_dataset is not None:
                logger.warning(
                    "`calibration_dataset` was provided but will not be used as `weights_only` is set to `True`."
                )
        else:
            if calibration_dataset is None:
                raise ValueError(
                    "`calibration_dataset` is needed to compute the activations range during the calibration step and was not provided. "
                    "In case you only want to apply quantization on the weights, please set `weights_only=True`."
                )

        quantization_config = kwargs.pop("quantization_config", None)
        if quantization_config is not None:
            logger.warning(
                "The argument `quantization_config` is deprecated, and will be removed in optimum-intel v1.6.0, please use `ov_config` instead"
            )
        ov_config = ov_config or quantization_config

        if isinstance(self.model, OVBaseModel):
            self._quantize_ovbasemodel(
                calibration_dataset,
                save_directory,
                batch_size,
                data_collator,
                remove_unused_columns,
                weights_only,
                ov_config,
                **kwargs,
            )
        elif isinstance(self.model, torch.nn.Module):
            self._quantize_torchmodel(
                calibration_dataset,
                save_directory,
                ov_config,
                file_name,
                batch_size,
                data_collator,
                remove_unused_columns,
                weights_only,
            )
        else:
            raise TypeError(f"Unsupported model type: {type(self.model)}")

    def _quantize_ovbasemodel(
        self,
        calibration_dataset: "Dataset",
        save_directory: Union[str, Path],
        batch_size: int = 1,
        data_collator: Optional[DataCollator] = None,
        remove_unused_columns: bool = True,
        weights_only: bool = False,
        ov_config: OVConfig = None,
        **kwargs,
    ):
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        if weights_only:
            # Use default 8-bit compression if not provided
            q_config = (
                OVWeightQuantizationConfig(bits=8, sym=True) if ov_config is None else ov_config.quantization_config
            )
            _weight_only_quantization(self.model, q_config)
            self.model.save_pretrained(save_directory)
            return

        calibration_dataloader = self._get_calibration_dataloader(
            calibration_dataset=calibration_dataset,
            batch_size=batch_size,
            remove_unused_columns=remove_unused_columns,
            data_collator=data_collator,
        )

        if self.model.export_feature == "text-generation" and self.model.use_cache:
            # Prefetch past_key_values
            self.model.update_pkv_precision(True)
            self.model.compile()
            subset_size = kwargs.get("subset_size", 300)
            data_cache = []

            # Temporarily wrap the inference request so the inputs produced during generation are recorded
            # as calibration data; the original request is restored even if generation fails.
            original_request = self.model.request
            self.model.request = InferRequestWrapper(original_request, data_cache)
            try:
                for data in calibration_dataloader:
                    self.model.generate(**data, max_new_tokens=1)
                    if len(data_cache) >= subset_size:
                        break
            finally:
                self.model.request = original_request
            calibration_dataloader = data_cache

        # `model_type` and `fast_bias_correction` are passed explicitly below, so remove them from `kwargs`;
        # leaving them in raised "got multiple values for keyword argument" when the caller supplied them.
        model_type = kwargs.pop("model_type", None) or nncf.ModelType.TRANSFORMER
        fast_bias_correction = kwargs.pop("fast_bias_correction", True)

        # Actual model quantization
        quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
        quantized_model = nncf.quantize(
            self.model.model,
            quantization_dataset,
            model_type=model_type,
            fast_bias_correction=fast_bias_correction,
            **kwargs,
        )
        self.model.model = quantized_model
        self.model.save_pretrained(save_directory)

    def _quantize_torchmodel(
        self,
        calibration_dataset: "Dataset",
        save_directory: Union[str, Path],
        ov_config: OVConfig = None,
        file_name: Optional[str] = None,
        batch_size: int = 1,
        data_collator: Optional[DataCollator] = None,
        remove_unused_columns: bool = True,
        weights_only: bool = False,
    ):
        self._set_task()
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        ov_file_name = file_name if file_name is not None else OV_XML_FILE_NAME
        output_path = save_directory.joinpath(ov_file_name)
        output_path = output_path.with_suffix(".xml").as_posix()

        model_type = self.model.config.model_type.replace("_", "-")
        onnx_config_class = TasksManager.get_exporter_config_constructor(
            exporter="onnx",
            model=self.model,
            task=self.task,
            model_type=model_type,
        )

        if ov_config is None:
            logger.info(
                "No configuration describing the quantization process was provided, a default OVConfig will be generated."
            )
            ov_config = OVConfig()
        onnx_file_name = (
            ONNX_WEIGHTS_NAME
            if file_name is None and ov_config.save_onnx_model
            else Path(ov_file_name).with_suffix(".onnx")
        )

        task = self.task
        model = self.model
        self.model.config.save_pretrained(save_directory)
        if task.startswith("text-generation"):
            onnx_config = onnx_config_class(
                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
            )
            if model.config.use_cache:
                task = "text-generation-with-past"
        else:
            onnx_config = onnx_config_class(model.config)

        stateful = ensure_stateful_is_available() and ensure_export_task_support_stateful(task)

        if weights_only:
            if stateful:
                # patch model before weight compression
                model = patch_model_with_bettertransformer(model)

            dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
            device = get_model_device(model)
            dummy_inputs = tree_map(
                lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
            )
            check_dummy_inputs_are_allowed(model, dummy_inputs)

            nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs]))
        else:
            if stateful:
                # `Logger.warn` is a deprecated alias of `Logger.warning`.
                logger.warning(
                    "Quantization algorithm does not support optimized stateful models. "
                    "The original model without optimization will be quantized and export."
                )
                stateful = False

            calibration_dataloader = self._get_calibration_dataloader(
                calibration_dataset=calibration_dataset,
                batch_size=batch_size,
                remove_unused_columns=remove_unused_columns,
                data_collator=data_collator,
            )

            model_inputs = next(iter(calibration_dataloader))
            ov_config.add_input_info(model_inputs)
            nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
            nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
            controller, model = create_compressed_model(
                model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
            )
            model = controller.strip(do_copy=False)

        model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
        onnx_path = save_directory / onnx_file_name
        export_fn = export if not ov_config.save_onnx_model else export_pytorch_via_onnx
        opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
        opset = max(opset, MIN_ONNX_QDQ_OPSET)
        kwargs = {}
        if not ov_config.save_onnx_model:
            kwargs = {"stateful": stateful}
        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
        if is_onnx:
            # Load and save the compressed model
            model = core.read_model(onnx_path)
            # Model required second saving for applying weights compression transformations
            self._save_pretrained(model, output_path)
            # if onnx conversion happens as fallback for pytorch conversion, remove onnx model
            if not ov_config.save_onnx_model:
                os.remove(onnx_path)
                try:
                    os.remove(f"{onnx_path}_data")
                except FileNotFoundError:
                    pass

        ov_config.save_pretrained(save_directory)

    @staticmethod
    def _save_pretrained(model: openvino.runtime.Model, output_path: str):
        compress_quantize_weights_transformation(model)
        openvino.save_model(model, output_path, compress_to_fp16=False)

    def _set_task(self):
        if self.task is None:
            self.task = TasksManager.infer_task_from_model(self.model.config._name_or_path)
            if self.task is None:
                raise ValueError(
                    "The task defining the model topology could not be extracted and needs to be specified for the ONNX export."
                )

        self.task = _TASK_ALIASES.get(self.task, self.task)

        if self.task == "text2text-generation":
            raise ValueError("Seq2Seq models are currently not supported for post-training static quantization.")

        if self.task == "image-to-text":
            raise ValueError("Image2Text models are currently not supported for post-training static quantization.")

    def get_calibration_dataset(
        self,
        dataset_name: str,
        num_samples: int = 100,
        dataset_config_name: Optional[str] = None,
        dataset_split: str = "train",
        preprocess_function: Optional[Callable] = None,
        preprocess_batch: bool = True,
        use_auth_token: bool = False,
        cache_dir: Optional[str] = None,
    ) -> "Dataset":
        """
        Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.

        Args:
            dataset_name (`str`):
                The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
                in generic formats and optionally a dataset script, if it requires some code to read the data files.
            num_samples (`int`, defaults to 100):
                The maximum number of samples composing the calibration dataset.
            dataset_config_name (`str`, *optional*):
                The name of the dataset configuration.
            dataset_split (`str`, defaults to `"train"`):
                Which split of the dataset to use to perform the calibration step.
            preprocess_function (`Callable`, *optional*):
                Processing function to apply to each example after loading dataset.
            preprocess_batch (`bool`, defaults to `True`):
                Whether the `preprocess_function` should be batched.
            use_auth_token (`bool`, defaults to `False`):
                Whether to use the token generated when running `transformers-cli login`.
            cache_dir (`str`, *optional*):
                Caching directory for a calibration dataset.
        Returns:
            The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
        """
        if not is_datasets_available():
            raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset"))

        from datasets import load_dataset

        calibration_dataset = load_dataset(
            dataset_name,
            name=dataset_config_name,
            split=dataset_split,
            use_auth_token=use_auth_token,
            cache_dir=cache_dir,
        )

        if num_samples is not None:
            num_samples = min(num_samples, len(calibration_dataset))
            calibration_dataset = calibration_dataset.shuffle(seed=self.seed).select(range(num_samples))

        if preprocess_function is not None:
            calibration_dataset = calibration_dataset.map(preprocess_function, batched=preprocess_batch)

        return calibration_dataset

    def _get_calibration_dataloader(
        self,
        calibration_dataset: "Dataset",
        batch_size: int,
        remove_unused_columns: bool,
        data_collator: Optional[DataCollator] = None,
    ) -> OVDataLoader:
        data_collator = data_collator if data_collator is not None else default_data_collator
        if remove_unused_columns:
            calibration_dataset = self._remove_unused_columns(calibration_dataset)
        self.input_names = calibration_dataset.column_names
        # Seeded generator so that the calibration sample order is reproducible.
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        sampler = RandomSampler(calibration_dataset, generator=generator)
        calibration_dataloader = DataLoader(
            calibration_dataset, batch_size=batch_size, sampler=sampler, collate_fn=data_collator, drop_last=False
        )
        return OVDataLoader(calibration_dataloader)

    def _remove_unused_columns(self, dataset: "Dataset"):
        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
        return dataset.remove_columns(ignored_columns)
def _weight_only_quantization(
    model: OVBaseModel, quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]]
):
    """
    Compress the weights of `model.model` with `nncf.compress_weights` and store the result back on `model.model`.

    Args:
        model (`OVBaseModel`):
            The OpenVINO model wrapper whose `model` attribute is replaced by the compressed model.
        quantization_config (`OVWeightQuantizationConfig` or `Dict`, *optional*):
            The weight compression configuration. A dict is converted with `OVWeightQuantizationConfig.from_dict`.
            When `None`, default 8-bit symmetric compression is used; previously a `None` config crashed with
            an AttributeError on `config.dataset`.
    """
    if quantization_config is None:
        config = OVWeightQuantizationConfig(bits=8, sym=True)
    elif isinstance(quantization_config, dict):
        config = OVWeightQuantizationConfig.from_dict(quantization_config)
    else:
        config = quantization_config

    ov_model = model.model

    dataset = config.dataset
    if isinstance(config.dataset, str):
        # A dataset name: tokenize it (with the configured tokenizer, or the model's own one) and wrap it
        # in an `nncf.Dataset` whose items are turned into model inputs by `model.prepare_inputs`.
        tokenizer = config.tokenizer
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
        elif isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        from optimum.gptq.data import get_dataset, prepare_dataset

        dataset = get_dataset(config.dataset, tokenizer, seqlen=32)
        dataset = prepare_dataset(dataset)
        dataset = nncf.Dataset(dataset, lambda x: model.prepare_inputs(**x))

    sensitivity_metric = None
    if isinstance(config.sensitivity_metric, str):
        # Map e.g. "weight_quantization_error" to the `SensitivityMetric.WEIGHT_QUANTIZATION_ERROR` member.
        sensitivity_metric = getattr(SensitivityMetric, config.sensitivity_metric.upper())

    ignored_scope = None
    if isinstance(config.ignored_scope, dict):
        ignored_scope = IgnoredScope(**config.ignored_scope)

    # Any bit width other than 8 selects the INT4 modes.
    if config.bits == 8:
        mode = CompressWeightsMode.INT8_SYM if config.sym else CompressWeightsMode.INT8_ASYM
    else:
        mode = CompressWeightsMode.INT4_SYM if config.sym else CompressWeightsMode.INT4_ASYM

    model.model = nncf.compress_weights(
        ov_model,
        mode=mode,
        ratio=config.ratio,
        group_size=config.group_size,
        all_layers=config.all_layers,
        sensitivity_metric=sensitivity_metric,
        # awq=config.quant_method == "awq", # TODO : remove and add it back once nncf v2.9.0
        ignored_scope=ignored_scope,
        dataset=dataset,
    )