@@ -24,10 +24,9 @@
 import openvino
 import torch
 import transformers
-from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
+from nncf import CompressWeightsMode, IgnoredScope, SensitivityMetric
 from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
-from nncf.torch import create_compressed_model, register_default_init_args, register_module
-from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
+from nncf.torch import register_module
 from nncf.torch.initialization import PTInitializingDataLoader
 from openvino._offline_transformations import compress_quantize_weights_transformation
 from openvino.runtime import Core, Tensor
@@ -47,7 +46,7 @@
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
 from ..utils.modeling_utils import get_model_device
-from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig, OVWeightQuantizationConfig
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import (
     MAX_ONNX_OPSET,
@@ -240,8 +239,6 @@ def quantize(
         if ov_config is not None:
             if not isinstance(ov_config, OVConfig):
                 raise TypeError(f"`ov_config` should be an `OVConfig`, but got: {type(ov_config)} instead.")
-            elif ov_config.compression is None:
-                ov_config.compression = DEFAULT_QUANTIZATION_CONFIG
 
         if isinstance(self.model, OVBaseModel):
             self._quantize_ovbasemodel(
@@ -319,7 +316,7 @@ def _quantize_ovbasemodel(
             calibration_dataloader = data_cache
 
         # Actual model quantization
-        quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
+        quantization_dataset = nncf.Dataset(calibration_dataloader)
         quantized_model = nncf.quantize(
             self.model.model,
             quantization_dataset,
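Note on the hunk above: `nncf.Dataset` takes an optional transform function that is applied to each item before it reaches the model, so the identity `lambda x: x` was redundant and can simply be dropped. A minimal sketch of the equivalence, using a toy calibration list in place of the quantizer's real data cache:

import nncf

# Toy calibration items standing in for the quantizer's data cache.
calibration_items = [{"input_ids": [[101, 2023, 102]]}]

# Equivalent datasets: with no transform function, items are passed as-is.
ds_explicit = nncf.Dataset(calibration_items, lambda x: x)
ds_implicit = nncf.Dataset(calibration_items)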
@@ -340,6 +337,7 @@ def _quantize_torchmodel(
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = True,
         weights_only: bool = False,
+        **kwargs,
     ):
         self._set_task()
         save_directory = Path(save_directory)
@@ -360,7 +358,7 @@ def _quantize_torchmodel(
             logger.info(
                 "No configuration describing the quantization process was provided, a default OVConfig will be generated."
             )
-            ov_config = OVConfig(compression=DEFAULT_QUANTIZATION_CONFIG)
+            ov_config = OVConfig()
         onnx_file_name = (
             ONNX_WEIGHTS_NAME
             if file_name is None and ov_config.save_onnx_model
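Note: with DEFAULT_QUANTIZATION_CONFIG removed here and in `quantize()` above, a bare `OVConfig()` no longer expands into the old preset; quantization settings now come from `nncf.quantize()` defaults plus whatever `**kwargs` the caller forwards. A small illustrative check (the assertion reflects my reading of OVConfig's defaults, not a documented guarantee):

from optimum.intel import OVConfig

ov_config = OVConfig()
# Assumption: no compression preset is attached implicitly anymore.
assert ov_config.compression is None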
@@ -398,7 +396,7 @@ def _quantize_torchmodel(
             if stateful:
                 logger.warn(
                     "Quantization algorithm does not support optimized stateful models. "
-                    "The original model without optimization will be quantized and export."
+                    "The original model without optimization will be quantized and exported."
                 )
                 stateful = False
@@ -409,25 +407,25 @@ def _quantize_torchmodel(
                 data_collator=data_collator,
             )
 
-            model_inputs = next(iter(calibration_dataloader))
-            ov_config.add_input_info(model_inputs)
-            nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
-            nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
-            controller, model = create_compressed_model(
-                model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
+            quantization_dataset = nncf.Dataset(calibration_dataloader)
+            model = nncf.quantize(
+                model,
+                quantization_dataset,
+                model_type=kwargs.pop("model_type", nncf.ModelType.TRANSFORMER),
+                fast_bias_correction=kwargs.pop("fast_bias_correction", True),
+                **kwargs,
             )
-            model = controller.strip(do_copy=False)
 
         model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
         export_fn = export if not ov_config.save_onnx_model else export_pytorch_via_onnx
         opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
         opset = max(opset, MIN_ONNX_QDQ_OPSET)
-        kwargs = {}
+        export_kwargs = {}
         if not ov_config.save_onnx_model:
-            kwargs = {"stateful": stateful}
+            export_kwargs = {"stateful": stateful}
 
-        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
+        _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **export_kwargs)
         if is_onnx:
             # Load and save the compressed model
             model = core.read_model(onnx_path)
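Taken together, this hunk replaces NNCF's training-time controller API (NNCFConfig, create_compressed_model, controller.strip) with a single post-training nncf.quantize() call; it also pops model_type and fast_bias_correction from kwargs before forwarding the rest, so they cannot be passed twice. A minimal sketch of the new flow, assuming a generic PyTorch model and calibration dataloader (the wrapper function name is illustrative, not part of the codebase):

import nncf
import torch

def quantize_with_ptq(model: torch.nn.Module, calibration_dataloader):
    # Wrap the calibration data; no transform function is needed when items
    # already match the model's input format.
    quantization_dataset = nncf.Dataset(calibration_dataloader)
    # One post-training quantization call replaces the old NNCFConfig /
    # create_compressed_model / controller.strip() sequence.
    return nncf.quantize(
        model,
        quantization_dataset,
        model_type=nncf.ModelType.TRANSFORMER,  # transformer-aware quantization patterns
        fast_bias_correction=True,              # the faster bias-correction variant
    )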