@@ -24,22 +24,26 @@
 import transformers
 from accelerate.data_loader import DataLoaderStateMixin
 from datasets import Dataset, load_dataset
-from nncf import NNCFConfig, compress_weights
+from nncf import NNCFConfig
 from nncf.torch import create_compressed_model, register_default_init_args, register_module
 from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
 from nncf.torch.initialization import PTInitializingDataLoader
 from openvino._offline_transformations import compress_quantize_weights_transformation
 from openvino.runtime import Core, Tensor
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader, RandomSampler
 from transformers import DataCollator, PreTrainedModel, default_data_collator
 from transformers.pytorch_utils import Conv1D
 
+from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.tasks import TasksManager
 from optimum.quantization_base import OptimumQuantizer
 
 from ...exporters.openvino import export, export_pytorch_via_onnx
-from ...exporters.openvino.stateful import ensure_export_task_support_stateful
+from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
+from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
 from ..utils.constant import _TASK_ALIASES
+from ..utils.modeling_utils import get_model_device
 from .configuration import OVConfig
 from .modeling_base import OVBaseModel
 from .modeling_decoder import OVBaseDecoderModel
@@ -361,9 +365,7 @@ def _quantize_ovcausallm(
             self.model.model,
             quantization_dataset,
             model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
-            fast_bias_correction=True
-            if not kwargs.get("fast_bias_correction")
-            else kwargs.get("fast_bias_correction"),
+            fast_bias_correction=kwargs.get("fast_bias_correction", True),
             **kwargs,
         )
         self.model.model = quantized_model
@@ -405,13 +407,42 @@ def _quantize_torchmodel(
             if file_name is None and ov_config.save_onnx_model
             else Path(ov_file_name).with_suffix(".onnx")
         )
+
+        task = self.task
+        model = self.model
+        self.model.config.save_pretrained(save_directory)
+        if task.startswith("text-generation"):
+            onnx_config = onnx_config_class(
+                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
+            )
+            if model.config.use_cache:
+                task = "text-generation-with-past"
+        else:
+            onnx_config = onnx_config_class(model.config)
+
+        stateful = ensure_stateful_is_available() and ensure_export_task_support_stateful(task)
+
         if weights_only:
-            if getattr(self.model.config, "tie_word_embeddings", True):
-                # to fix problem with shared embedding weights in nncf compress_weights()
-                self.model.tie_weights()
-            compressed_model = compress_weights(self.model)
-            self.model = compressed_model
+            if stateful:
+                # patch model before weight compression
+                model = patch_model_with_bettertransformer(model)
+
+            dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
+            device = get_model_device(model)
+            dummy_inputs = tree_map(
+                lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
+            )
+            check_dummy_inputs_are_allowed(model, dummy_inputs)
+
+            nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs]))
         else:
+            if stateful:
+                logger.warning(
+                    "Quantization algorithm does not support optimized stateful models. "
+                    "The original model without optimization will be quantized and exported."
+                )
+                stateful = False
+
             calibration_dataloader = self._get_calibration_dataloader(
                 calibration_dataset=calibration_dataset,
                 batch_size=batch_size,
@@ -423,22 +454,10 @@ def _quantize_torchmodel(
             ov_config.add_input_info(model_inputs)
             nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
             nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
-            controller, compressed_model = create_compressed_model(
-                self.model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
-            )
-            compressed_model = controller.strip(do_copy=False)
-
-        task = self.task
-        model = self.model
-        self.model.config.save_pretrained(save_directory)
-        if task.startswith("text-generation"):
-            onnx_config = onnx_config_class(
-                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
+            controller, model = create_compressed_model(
+                model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
             )
-            if model.config.use_cache:
-                task = "text-generation-with-past"
-        else:
-            onnx_config = onnx_config_class(model.config)
+            model = controller.strip(do_copy=False)
 
         model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
@@ -447,7 +466,8 @@ def _quantize_torchmodel(
             opset = max(opset, MIN_ONNX_QDQ_OPSET)
         kwargs = {}
         if not ov_config.save_onnx_model:
-            kwargs = {"stateful": ensure_export_task_support_stateful(task)}
+            kwargs = {"stateful": stateful}
+
         _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
         if is_onnx:
             # Load and save the compressed model
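
Note on the weights-only branch above: it replaces the old compress_weights(self.model) call with nncf.compress_weights driven by a one-sample nncf.Dataset built from the ONNX config's dummy inputs, moved to the model's device via tree_map. A minimal standalone sketch of that flow, assuming model is a torch PreTrainedModel and onnx_config is an exporter config exposing generate_dummy_inputs() as in the diff; the device lookup inlines what get_model_device() is assumed to return:

import torch
import nncf
from torch.utils._pytree import tree_map

def compress_weights_with_tracing(model, onnx_config):
    # Build framework-native dummy inputs and move any tensors to the
    # model's device so the tracing forward pass does not mix devices.
    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
    device = next(model.parameters()).device  # hypothetical stand-in for get_model_device()
    dummy_inputs = tree_map(
        lambda v: v.to(device) if isinstance(v, torch.Tensor) else v, dummy_inputs
    )
    # A single example suffices here: the dataset is only used to trace
    # the graph before the weights are compressed in place.
    return nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs]))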