Skip to content

Commit 277d39a

Browse files
committed Feb 6, 2024
Make quantization_config a part of OVConfig in OVQuantizer
1 parent de4d192 commit 277d39a

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed
 

‎optimum/intel/openvino/configuration.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Dict, List, Optional, Union
1616

1717
import torch
18+
from transformers.utils.quantization_config import QuantizationConfigMixin
1819

1920
from optimum.configuration_utils import BaseConfig
2021

@@ -83,6 +84,7 @@ def __init__(
8384
compression: Union[List[Dict], Dict, None] = None,
8485
input_info: Optional[List] = None,
8586
save_onnx_model: bool = False,
87+
quantization_config: Optional[QuantizationConfigMixin] = None,
8688
**kwargs,
8789
):
8890
super().__init__()
@@ -91,6 +93,7 @@ def __init__(
9193
self.save_onnx_model = save_onnx_model
9294
self._enable_standard_onnx_export_option()
9395
self.optimum_version = kwargs.pop("optimum_version", None)
96+
self.quantization_config = quantization_config
9497

9598
def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False):
9699
self.input_info = [

‎optimum/intel/openvino/quantization.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from torch.utils.data import DataLoader, RandomSampler
3434
from transformers import DataCollator, PreTrainedModel, default_data_collator
3535
from transformers.pytorch_utils import Conv1D
36-
from transformers.utils.quantization_config import QuantizationConfigMixin
3736

3837
from optimum.exporters.tasks import TasksManager
3938
from optimum.quantization_base import OptimumQuantizer
@@ -159,7 +158,6 @@ def quantize(
159158
self,
160159
calibration_dataset: Dataset = None,
161160
save_directory: Union[str, Path] = None,
162-
quantization_config: QuantizationConfigMixin = None,
163161
ov_config: OVConfig = None,
164162
file_name: Optional[str] = None,
165163
batch_size: int = 1,
@@ -234,7 +232,7 @@ def quantize(
234232
data_collator,
235233
remove_unused_columns,
236234
weights_only,
237-
quantization_config,
235+
ov_config,
238236
**kwargs,
239237
)
240238
elif isinstance(self.model, OVBaseModel):
@@ -313,13 +311,14 @@ def _quantize_ovcausallm(
313311
data_collator: Optional[DataCollator] = None,
314312
remove_unused_columns: bool = True,
315313
weights_only: bool = False,
316-
quantization_config: QuantizationConfigMixin = None,
314+
ov_config: OVConfig = None,
317315
**kwargs,
318316
):
319317
save_directory = Path(save_directory)
320318
save_directory.mkdir(parents=True, exist_ok=True)
321319

322320
if weights_only:
321+
quantization_config = None if ov_config is None else ov_config.quantization_config
323322
if quantization_config is None:
324323
# Use default 8-bit compression
325324
self.model.model = nncf.compress_weights(self.model.model)

0 commit comments

Comments
 (0)