Transition to a newer NNCF API for PyTorch model quantization #630

Merged · 3 commits · Mar 26, 2024
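This PR moves the PyTorch quantization path in `optimum/intel/openvino/quantization.py` off NNCF's legacy `NNCFConfig` / `create_compressed_model` flow and onto the post-training `nncf.quantize()` API, removing the `DEFAULT_QUANTIZATION_CONFIG` plumbing along the way. The user-facing entry point is unchanged; the sketch below shows a typical call, with the checkpoint name and calibration-dataset choices being illustrative rather than taken from this PR.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.intel import OVQuantizer

# Illustrative checkpoint; any supported transformers model works the same way.
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantizer = OVQuantizer.from_pretrained(model)

# Build a small calibration set; any dataset yielding tokenized model inputs is fine.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=lambda examples: tokenizer(
        examples["sentence"], padding="max_length", truncation=True
    ),
    num_samples=100,
    dataset_split="train",
)

# With this PR, the PyTorch path underneath runs nncf.quantize() instead of create_compressed_model().
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="quantized_model")
```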
56 changes: 22 additions & 34 deletions optimum/intel/openvino/quantization.py
@@ -24,10 +24,9 @@
import openvino
import torch
import transformers
from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
from nncf import CompressWeightsMode, IgnoredScope, SensitivityMetric
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.torch import create_compressed_model, register_default_init_args, register_module
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch import register_module
from nncf.torch.initialization import PTInitializingDataLoader
from openvino._offline_transformations import compress_quantize_weights_transformation
from openvino.runtime import Core, Tensor
@@ -47,7 +46,7 @@
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
from ..utils.modeling_utils import get_model_device
from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig, OVWeightQuantizationConfig
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel
from .utils import (
MAX_ONNX_OPSET,
@@ -240,8 +239,6 @@ def quantize(
if ov_config is not None:
if not isinstance(ov_config, OVConfig):
raise TypeError(f"`ov_config` should be an `OVConfig`, but got: {type(ov_config)} instead.")
elif ov_config.compression is None:
ov_config.compression = DEFAULT_QUANTIZATION_CONFIG

if isinstance(self.model, OVBaseModel):
self._quantize_ovbasemodel(
@@ -263,7 +260,6 @@
self._quantize_torchmodel(
calibration_dataset,
save_directory,
ov_config,
file_name,
batch_size,
data_collator,
@@ -319,7 +315,7 @@ def _quantize_ovbasemodel(
calibration_dataloader = data_cache

# Actual model quantization
quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
quantization_dataset = nncf.Dataset(calibration_dataloader)
quantized_model = nncf.quantize(
self.model.model,
quantization_dataset,
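A note on the hunk above: `nncf.Dataset` applies an optional transform function to each item before it reaches the model, so the identity `lambda x: x` was redundant and can simply be omitted. A minimal illustration, with tensors standing in for whatever the calibration dataloader actually yields:

```python
import nncf
import torch

# Stand-in for the items produced by the calibration dataloader.
data_items = [torch.randn(1, 8) for _ in range(4)]

ds_old = nncf.Dataset(data_items, lambda x: x)  # explicit identity transform (removed form)
ds_new = nncf.Dataset(data_items)               # transform omitted (new form), behaves the same
```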
@@ -334,12 +330,13 @@ def _quantize_torchmodel(
self,
calibration_dataset: "Dataset",
save_directory: Union[str, Path],
ov_config: OVConfig = None,
file_name: Optional[str] = None,
batch_size: int = 1,
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
weights_only: bool = False,
save_onnx_model: bool = False,
**kwargs,
):
self._set_task()
save_directory = Path(save_directory)
@@ -356,15 +353,8 @@
model_type=model_type,
)

if ov_config is None:
logger.info(
"No configuration describing the quantization process was provided, a default OVConfig will be generated."
)
ov_config = OVConfig(compression=DEFAULT_QUANTIZATION_CONFIG)
onnx_file_name = (
ONNX_WEIGHTS_NAME
if file_name is None and ov_config.save_onnx_model
else Path(ov_file_name).with_suffix(".onnx")
ONNX_WEIGHTS_NAME if file_name is None and save_onnx_model else Path(ov_file_name).with_suffix(".onnx")
)

task = self.task
@@ -398,7 +388,7 @@
if stateful:
logger.warn(
"Quantization algorithm does not support optimized stateful models. "
"The original model without optimization will be quantized and export."
"The original model without optimization will be quantized and exported."
)
stateful = False

@@ -409,40 +399,38 @@
data_collator=data_collator,
)

model_inputs = next(iter(calibration_dataloader))
ov_config.add_input_info(model_inputs)
nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
controller, model = create_compressed_model(
model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
quantization_dataset = nncf.Dataset(calibration_dataloader)
model = nncf.quantize(
model,
quantization_dataset,
model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
fast_bias_correction=kwargs.get("fast_bias_correction", True),
**kwargs,
)
model = controller.strip(do_copy=False)

model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
model_path = save_directory / (onnx_file_name if save_onnx_model else ov_file_name)
onnx_path = save_directory / onnx_file_name
export_fn = export if not ov_config.save_onnx_model else export_pytorch_via_onnx
export_fn = export if not save_onnx_model else export_pytorch_via_onnx
opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
opset = max(opset, MIN_ONNX_QDQ_OPSET)
kwargs = {}
if not ov_config.save_onnx_model:
kwargs = {"stateful": stateful}
export_kwargs = {}
if not save_onnx_model:
export_kwargs = {"stateful": stateful}

_, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
_, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **export_kwargs)
if is_onnx:
# Load and save the compressed model
model = core.read_model(onnx_path)
# Model requires a second save to apply weights compression transformations
self._save_pretrained(model, output_path)
# if onnx conversion happens as fallback for pytorch conversion, remove onnx model
if not ov_config.save_onnx_model:
if not save_onnx_model:
os.remove(onnx_path)
try:
os.remove(f"{onnx_path}_data")
except FileNotFoundError:
pass

ov_config.save_pretrained(save_directory)

@staticmethod
def _save_pretrained(model: openvino.runtime.Model, output_path: str):
compress_quantize_weights_transformation(model)
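In `_quantize_torchmodel`, the hunk above swaps the compression-controller flow (`register_default_init_args`, `create_compressed_model`, `controller.strip`) for a single post-training call. A minimal sketch of that call on a bare PyTorch module, using the same arguments the new code passes; the model and calibration data here are placeholders, whereas the real code uses the task model and the dataloader built from the user's calibration dataset:

```python
import nncf
import torch

# Placeholder model and calibration data for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
calibration_dataloader = [torch.randn(1, 16) for _ in range(10)]

quantization_dataset = nncf.Dataset(calibration_dataloader)
quantized_model = nncf.quantize(
    model,
    quantization_dataset,
    model_type=nncf.ModelType.TRANSFORMER,  # default used by the new code path
    fast_bias_correction=True,
)
```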
9 changes: 0 additions & 9 deletions tests/openvino/test_quantization.py
@@ -57,7 +56,6 @@
OVWeightQuantizationConfig,
)

from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG, DEFAULT_QUANTIZATION_CONFIG
from optimum.intel.openvino.quantization import InferRequestWrapper
from optimum.intel.utils.import_utils import is_openvino_version
from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8
@@ -110,10 +109,6 @@ def preprocess_function(examples, tokenizer):
outputs = model(**tokens)
self.assertTrue("logits" in outputs)

# Verify that the configuration is correctly saved and loaded
loaded_config = OVConfig.from_pretrained(tmp_dir)
self.assertEqual(DEFAULT_QUANTIZATION_CONFIG, loaded_config.to_dict()["compression"])

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
task = model_cls.export_feature
@@ -265,10 +260,6 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
outputs = model(**tokens)
self.assertTrue("logits" in outputs)

# Verify that the configuration is correctly saved and loaded
loaded_config = OVConfig.from_pretrained(tmp_dir)
self.assertIsNotNone(loaded_config)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature