Commit a3bf172

Transition to a newer NNCF API for PyTorch model quantization (#630)
* Replace create_compressed_model call with quantize call for PyTorch backend
* Remove ov_config from torch quantization arguments
* Tweak test
1 parent 1f0ab3a commit a3bf172
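At the user-facing level, the PyTorch path of `OVQuantizer.quantize` no longer relies on an `OVConfig` compression section to drive quantization; calibration data and optional keyword arguments are handed straight to NNCF. A minimal sketch of a post-change call, assuming the standard optimum-intel workflow (the checkpoint, dataset, and preprocessing below are illustrative, not taken from this commit):

from functools import partial

from optimum.intel import OVQuantizer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess_fn(examples, tokenizer):
    return tokenizer(examples["sentence"], padding=True, truncation=True, max_length=128)

quantizer = OVQuantizer.from_pretrained(model)
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=100,
    dataset_split="train",
)
# No OVConfig is needed here; the torch backend now calls nncf.quantize directly.
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="ov_int8_model")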

File tree: 2 files changed (+22 −43 lines)

optimum/intel/openvino/quantization.py (+22 −34)
@@ -24,10 +24,9 @@
 import openvino
 import torch
 import transformers
-from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
+from nncf import CompressWeightsMode, IgnoredScope, SensitivityMetric
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
-from nncf.torch import create_compressed_model, register_default_init_args, register_module
-from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
+from nncf.torch import register_module
 from nncf.torch.initialization import PTInitializingDataLoader
 from openvino._offline_transformations import compress_quantize_weights_transformation
 from openvino.runtime import Core, Tensor
@@ -47,7 +46,7 @@
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
 from ..utils.modeling_utils import get_model_device
-from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig, OVWeightQuantizationConfig
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import (
     MAX_ONNX_OPSET,
@@ -240,8 +239,6 @@ def quantize(
         if ov_config is not None:
             if not isinstance(ov_config, OVConfig):
                 raise TypeError(f"`ov_config` should be an `OVConfig`, but got: {type(ov_config)} instead.")
-            elif ov_config.compression is None:
-                ov_config.compression = DEFAULT_QUANTIZATION_CONFIG

         if isinstance(self.model, OVBaseModel):
             self._quantize_ovbasemodel(
@@ -263,7 +260,6 @@ def quantize(
             self._quantize_torchmodel(
                 calibration_dataset,
                 save_directory,
-                ov_config,
                 file_name,
                 batch_size,
                 data_collator,
@@ -319,7 +315,7 @@ def _quantize_ovbasemodel(
             calibration_dataloader = data_cache

         # Actual model quantization
-        quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
+        quantization_dataset = nncf.Dataset(calibration_dataloader)
         quantized_model = nncf.quantize(
             self.model.model,
             quantization_dataset,
@@ -334,12 +330,13 @@ def _quantize_torchmodel(
         self,
         calibration_dataset: "Dataset",
         save_directory: Union[str, Path],
-        ov_config: OVConfig = None,
         file_name: Optional[str] = None,
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
         weights_only: bool = False,
+        save_onnx_model: bool = False,
+        **kwargs,
     ):
         self._set_task()
         save_directory = Path(save_directory)
@@ -356,15 +353,8 @@ def _quantize_torchmodel(
             model_type=model_type,
         )

-        if ov_config is None:
-            logger.info(
-                "No configuration describing the quantization process was provided, a default OVConfig will be generated."
-            )
-            ov_config = OVConfig(compression=DEFAULT_QUANTIZATION_CONFIG)
         onnx_file_name = (
-            ONNX_WEIGHTS_NAME
-            if file_name is None and ov_config.save_onnx_model
-            else Path(ov_file_name).with_suffix(".onnx")
+            ONNX_WEIGHTS_NAME if file_name is None and save_onnx_model else Path(ov_file_name).with_suffix(".onnx")
         )

         task = self.task
@@ -398,7 +388,7 @@ def _quantize_torchmodel(
             if stateful:
                 logger.warn(
                     "Quantization algorithm does not support optimized stateful models. "
-                    "The original model without optimization will be quantized and export."
+                    "The original model without optimization will be quantized and exported."
                 )
                 stateful = False

@@ -409,40 +399,38 @@ def _quantize_torchmodel(
             data_collator=data_collator,
         )

-        model_inputs = next(iter(calibration_dataloader))
-        ov_config.add_input_info(model_inputs)
-        nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
-        nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
-        controller, model = create_compressed_model(
-            model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
+        quantization_dataset = nncf.Dataset(calibration_dataloader)
+        model = nncf.quantize(
+            model,
+            quantization_dataset,
+            model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
+            fast_bias_correction=kwargs.get("fast_bias_correction", True),
+            **kwargs,
         )
-        model = controller.strip(do_copy=False)

-        model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
+        model_path = save_directory / (onnx_file_name if save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
-        export_fn = export if not ov_config.save_onnx_model else export_pytorch_via_onnx
+        export_fn = export if not save_onnx_model else export_pytorch_via_onnx
         opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
         opset = max(opset, MIN_ONNX_QDQ_OPSET)
-        kwargs = {}
-        if not ov_config.save_onnx_model:
-            kwargs = {"stateful": stateful}
+        export_kwargs = {}
+        if not save_onnx_model:
+            export_kwargs = {"stateful": stateful}

-        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
+        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **export_kwargs)
         if is_onnx:
             # Load and save the compressed model
             model = core.read_model(onnx_path)
             # Model required second saving for appling weights compression transformations
             self._save_pretrained(model, output_path)
             # if onnx conversion happens as fallback for pytorch conversion, remove onnx model
-            if not ov_config.save_onnx_model:
+            if not save_onnx_model:
                 os.remove(onnx_path)
                 try:
                     os.remove(f"{onnx_path}_data")
                 except FileNotFoundError:
                     pass

-        ov_config.save_pretrained(save_directory)
-
     @staticmethod
     def _save_pretrained(model: openvino.runtime.Model, output_path: str):
         compress_quantize_weights_transformation(model)
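The core of the change above is replacing NNCF's config-driven `create_compressed_model` flow with a direct `nncf.quantize` call on the PyTorch model. A standalone sketch of that new-style flow, assuming an iterable of pre-tokenized samples (the model and calibration data below are illustrative, not part of this commit):

import nncf
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")  # illustrative model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Any iterable of samples works; the optional transform_fn maps each item to model inputs
# and can be omitted when items already match the expected input format.
samples = [tokenizer(text, return_tensors="pt") for text in ["example one", "example two"]]
calibration_dataset = nncf.Dataset(samples, lambda sample: dict(sample))

quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    model_type=nncf.ModelType.TRANSFORMER,  # transformer-aware quantization scheme
    fast_bias_correction=True,
)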

tests/openvino/test_quantization.py (−9)
@@ -57,7 +57,6 @@
     OVWeightQuantizationConfig,
 )

-from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG, DEFAULT_QUANTIZATION_CONFIG
 from optimum.intel.openvino.quantization import InferRequestWrapper
 from optimum.intel.utils.import_utils import is_openvino_version
 from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8
@@ -110,10 +109,6 @@ def preprocess_function(examples, tokenizer):
             outputs = model(**tokens)
             self.assertTrue("logits" in outputs)

-            # Verify that that the configuration is correctly saved and loaded
-            loaded_config = OVConfig.from_pretrained(tmp_dir)
-            self.assertEqual(DEFAULT_QUANTIZATION_CONFIG, loaded_config.to_dict()["compression"])
-
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
     def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         task = model_cls.export_feature
@@ -265,10 +260,6 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
             outputs = model(**tokens)
             self.assertTrue("logits" in outputs)

-            # Verify that that the configuration is correctly saved and loaded
-            loaded_config = OVConfig.from_pretrained(tmp_dir)
-            self.assertIsNotNone(loaded_config)
-
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
     def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
