From 25bbc2b2ba8f515aa19b092344cd58ee3e79ae52 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 22 Mar 2024 17:05:00 +0100
Subject: [PATCH 1/3] Replace create_compressed_model call with quantize call for PyTorch backend

---
 optimum/intel/openvino/quantization.py | 36 ++++++++++++--------------
 tests/openvino/test_quantization.py    |  5 ----
 2 files changed, 17 insertions(+), 24 deletions(-)

diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 2022a495d8..261516952e 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -24,10 +24,9 @@
 import openvino
 import torch
 import transformers
-from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
+from nncf import CompressWeightsMode, IgnoredScope, SensitivityMetric
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
-from nncf.torch import create_compressed_model, register_default_init_args, register_module
-from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
+from nncf.torch import register_module
 from nncf.torch.initialization import PTInitializingDataLoader
 from openvino._offline_transformations import compress_quantize_weights_transformation
 from openvino.runtime import Core, Tensor
@@ -47,7 +46,7 @@
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
 from ..utils.modeling_utils import get_model_device
-from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig, OVWeightQuantizationConfig
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import (
     MAX_ONNX_OPSET,
@@ -240,8 +239,6 @@ def quantize(
         if ov_config is not None:
             if not isinstance(ov_config, OVConfig):
                 raise TypeError(f"`ov_config` should be an `OVConfig`, but got: {type(ov_config)} instead.")
-            elif ov_config.compression is None:
-                ov_config.compression = DEFAULT_QUANTIZATION_CONFIG

         if isinstance(self.model, OVBaseModel):
             self._quantize_ovbasemodel(
@@ -319,7 +316,7 @@ def _quantize_ovbasemodel(
                 calibration_dataloader = data_cache

         # Actual model quantization
-        quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
+        quantization_dataset = nncf.Dataset(calibration_dataloader)
         quantized_model = nncf.quantize(
             self.model.model,
             quantization_dataset,
@@ -340,6 +337,7 @@ def _quantize_torchmodel(
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
         weights_only: bool = False,
+        **kwargs,
     ):
         self._set_task()
         save_directory = Path(save_directory)
@@ -360,7 +358,7 @@ def _quantize_torchmodel(
             logger.info(
                 "No configuration describing the quantization process was provided, a default OVConfig will be generated."
             )
-            ov_config = OVConfig(compression=DEFAULT_QUANTIZATION_CONFIG)
+            ov_config = OVConfig()
         onnx_file_name = (
             ONNX_WEIGHTS_NAME
             if file_name is None and ov_config.save_onnx_model
@@ -398,7 +396,7 @@ def _quantize_torchmodel(
             if stateful:
                 logger.warn(
                     "Quantization algorithm does not support optimized stateful models. "
-                    "The original model without optimization will be quantized and export."
+                    "The original model without optimization will be quantized and exported."
                 )
                 stateful = False

@@ -409,25 +407,25 @@ def _quantize_torchmodel(
                 data_collator=data_collator,
             )

-            model_inputs = next(iter(calibration_dataloader))
-            ov_config.add_input_info(model_inputs)
-            nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
-            nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
-            controller, model = create_compressed_model(
-                model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
+            quantization_dataset = nncf.Dataset(calibration_dataloader)
+            model = nncf.quantize(
+                model,
+                quantization_dataset,
+                model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
+                fast_bias_correction=kwargs.get("fast_bias_correction", True),
+                **kwargs,
             )
-            model = controller.strip(do_copy=False)

         model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
         export_fn = export if not ov_config.save_onnx_model else export_pytorch_via_onnx
         opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
         opset = max(opset, MIN_ONNX_QDQ_OPSET)
-        kwargs = {}
+        export_kwargs = {}
         if not ov_config.save_onnx_model:
-            kwargs = {"stateful": stateful}
+            export_kwargs = {"stateful": stateful}

-        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
+        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **export_kwargs)
         if is_onnx:
             # Load and save the compressed model
             model = core.read_model(onnx_path)
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index c7fb00e12d..063d1e214f 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -57,7 +57,6 @@
     OVWeightQuantizationConfig,
 )

-from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG, DEFAULT_QUANTIZATION_CONFIG
 from optimum.intel.openvino.quantization import InferRequestWrapper
 from optimum.intel.utils.import_utils import is_openvino_version
 from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8
@@ -110,10 +109,6 @@ def preprocess_function(examples, tokenizer):
         outputs = model(**tokens)
         self.assertTrue("logits" in outputs)

-        # Verify that that the configuration is correctly saved and loaded
-        loaded_config = OVConfig.from_pretrained(tmp_dir)
-        self.assertEqual(DEFAULT_QUANTIZATION_CONFIG, loaded_config.to_dict()["compression"])
-
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
     def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         task = model_cls.export_feature

From 5544e61e3ac5c6c187cb7d61b3c6e75c4a83b86f Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Mon, 25 Mar 2024 20:13:16 +0100
Subject: [PATCH 2/3] Remove ov_config from torch quantization arguments

---
 optimum/intel/openvino/quantization.py | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 261516952e..3a40ff4f3e 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -260,7 +260,6 @@ def quantize(
             self._quantize_torchmodel(
                 calibration_dataset,
                 save_directory,
-                ov_config,
                 file_name,
                 batch_size,
                 data_collator,
@@ -331,12 +330,12 @@ def _quantize_torchmodel(
         self,
         calibration_dataset: "Dataset",
         save_directory: Union[str, Path],
-        ov_config: OVConfig = None,
         file_name: Optional[str] = None,
         batch_size: int = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
         weights_only: bool = False,
+        save_onnx_model: bool = False,
         **kwargs,
     ):
         self._set_task()
@@ -354,15 +353,8 @@ def _quantize_torchmodel(
             model_type=model_type,
         )

-        if ov_config is None:
-            logger.info(
-                "No configuration describing the quantization process was provided, a default OVConfig will be generated."
-            )
-            ov_config = OVConfig()
         onnx_file_name = (
-            ONNX_WEIGHTS_NAME
-            if file_name is None and ov_config.save_onnx_model
-            else Path(ov_file_name).with_suffix(".onnx")
+            ONNX_WEIGHTS_NAME if file_name is None and save_onnx_model else Path(ov_file_name).with_suffix(".onnx")
         )

         task = self.task
@@ -416,13 +408,13 @@ def _quantize_torchmodel(
             **kwargs,
         )

-        model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
+        model_path = save_directory / (onnx_file_name if save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
-        export_fn = export if not ov_config.save_onnx_model else export_pytorch_via_onnx
+        export_fn = export if not save_onnx_model else export_pytorch_via_onnx
         opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
         opset = max(opset, MIN_ONNX_QDQ_OPSET)
         export_kwargs = {}
-        if not ov_config.save_onnx_model:
+        if not save_onnx_model:
             export_kwargs = {"stateful": stateful}

         _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **export_kwargs)
@@ -432,15 +424,13 @@ def _quantize_torchmodel(
         if is_onnx:
             # Load and save the compressed model
             model = core.read_model(onnx_path)
             # Model required second saving for appling weights compression transformations
             self._save_pretrained(model, output_path)
             # if onnx conversion happens as fallback for pytorch conversion, remove onnx model
-            if not ov_config.save_onnx_model:
+            if not save_onnx_model:
                 os.remove(onnx_path)
                 try:
                     os.remove(f"{onnx_path}_data")
                 except FileNotFoundError:
                     pass

-        ov_config.save_pretrained(save_directory)
-
     @staticmethod
     def _save_pretrained(model: openvino.runtime.Model, output_path: str):
         compress_quantize_weights_transformation(model)

From ae04ab627fa3492361ffeaf2b64b0bab8bcaaa5a Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Tue, 26 Mar 2024 10:17:07 +0100
Subject: [PATCH 3/3] Tweak test

---
 tests/openvino/test_quantization.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index 063d1e214f..8c166a5e8c 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -260,10 +260,6 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
         outputs = model(**tokens)
         self.assertTrue("logits" in outputs)

-        # Verify that that the configuration is correctly saved and loaded
-        loaded_config = OVConfig.from_pretrained(tmp_dir)
-        self.assertIsNotNone(loaded_config)
-
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
     def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
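
Note on the API change in this series: the PyTorch backend now goes through NNCF's post-training quantization entry point (nncf.quantize over an nncf.Dataset) instead of building an NNCFConfig and calling create_compressed_model. The sketch below shows that call in isolation; it is only an illustration, and the model name and calibration samples are placeholders rather than anything taken from this diff.

    import nncf
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Placeholder model and calibration data (illustrative only).
    model_id = "hf-internal-testing/tiny-random-bert"
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    calibration_samples = [dict(tokenizer("sample text %d" % i, return_tensors="pt")) for i in range(8)]

    # nncf.Dataset wraps any iterable of model inputs; no NNCFConfig is involved.
    calibration_dataset = nncf.Dataset(calibration_samples)

    # Same defaults the patched _quantize_torchmodel applies for transformer models.
    quantized_model = nncf.quantize(
        model,
        calibration_dataset,
        model_type=nncf.ModelType.TRANSFORMER,
        fast_bias_correction=True,
    )

Unlike create_compressed_model, this returns the quantized model directly, with no compression controller to strip afterwards, which is what lets the first patch drop the controller.strip(do_copy=False) step.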