@@ -24,7 +24,7 @@
 import transformers
 from accelerate.data_loader import DataLoaderStateMixin
 from datasets import Dataset, load_dataset
-from nncf import NNCFConfig, compress_weights
+from nncf import NNCFConfig
 from nncf.torch import create_compressed_model, register_default_init_args, register_module
 from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
 from nncf.torch.initialization import PTInitializingDataLoader
@@ -34,11 +34,13 @@
 from transformers import DataCollator, PreTrainedModel, default_data_collator
 from transformers.pytorch_utils import Conv1D
 
+from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.tasks import TasksManager
 from optimum.quantization_base import OptimumQuantizer
 
 from ...exporters.openvino import export, export_pytorch_via_onnx
-from ...exporters.openvino.stateful import ensure_export_task_support_stateful
+from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer
+from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available
 from ..utils.constant import _TASK_ALIASES
 from .configuration import OVConfig
 from .modeling_base import OVBaseModel
@@ -344,9 +346,7 @@ def __getattr__(self, attr):
             self.model.model,
             quantization_dataset,
             model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
-            fast_bias_correction=True
-            if not kwargs.get("fast_bias_correction")
-            else kwargs.get("fast_bias_correction"),
+            fast_bias_correction=True if not kwargs.get("fast_bias_correction") else kwargs.get("fast_bias_correction"),
             **kwargs,
         )
         self.model.model = quantized_model
@@ -388,13 +388,44 @@ def _quantize_torchmodel(
             if file_name is None and quantization_config.save_onnx_model
             else Path(ov_file_name).with_suffix(".onnx")
         )
+
+        task = self.task
+        model = self.model
+        self.model.config.save_pretrained(save_directory)
+        if task.startswith("text-generation"):
+            onnx_config = onnx_config_class(
+                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
+            )
+            if model.config.use_cache:
+                task = "text-generation-with-past"
+        else:
+            onnx_config = onnx_config_class(model.config)
+
+        stateful = ensure_stateful_is_available() and ensure_export_task_support_stateful(task)
+
         if weights_only:
-            if getattr(self.model.config, "tie_word_embeddings", True):
-                # to fix problem with shared embedding weights in nncf compress_weights()
-                self.model.tie_weights()
-            compressed_model = compress_weights(self.model)
-            self.model = compressed_model
+            from torch.utils._pytree import tree_map
+
+            if stateful:
+                # patch model before weight compression
+                model = patch_model_with_bettertransformer(model)
+
+            dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
+            device = self.model.device
+            dummy_inputs = tree_map(
+                lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
+            )
+            check_dummy_inputs_are_allowed(model, dummy_inputs)
+
+            nncf.compress_weights(self.model, dataset=nncf.Dataset([dummy_inputs]))
         else:
+            if stateful:
+                logger.warn(
+                    "Quantization algorithm does not support optimized stateful models. "
+                    "The original model without optimization will be quantized and exported."
+                )
+                stateful = False
+
             calibration_dataloader = self._get_calibration_dataloader(
                 calibration_dataset=calibration_dataset,
                 batch_size=batch_size,
@@ -411,26 +442,14 @@ def _quantize_torchmodel(
         )
         compressed_model = controller.strip(do_copy=False)
 
-        task = self.task
-        model = self.model
-        self.model.config.save_pretrained(save_directory)
-        if task.startswith("text-generation"):
-            onnx_config = onnx_config_class(
-                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
-            )
-            if model.config.use_cache:
-                task = "text-generation-with-past"
-        else:
-            onnx_config = onnx_config_class(model.config)
-
         model_path = save_directory / (onnx_file_name if quantization_config.save_onnx_model else ov_file_name)
         onnx_path = save_directory / onnx_file_name
         export_fn = export if not quantization_config.save_onnx_model else export_pytorch_via_onnx
         opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
         opset = max(opset, MIN_ONNX_QDQ_OPSET)
         kwargs = {}
         if not quantization_config.save_onnx_model:
-            kwargs = {"stateful": ensure_export_task_support_stateful(task)}
+            kwargs = {"stateful": stateful}
         _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
         if is_onnx:
             # Load and save the compressed model
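For reference, the weight-only branch in this diff reduces to one pattern: build trace-ready dummy inputs, move every tensor in the (possibly nested) input structure to the model's device with tree_map, validate them, and hand them to nncf.compress_weights via a single-sample nncf.Dataset. Below is a minimal, self-contained sketch of that pattern; the toy module and input shapes are illustrative assumptions, not code from this PR.

```python
# Minimal sketch of the weight-only compression pattern from the diff.
# The toy module and input shapes are illustrative assumptions.
import torch
from torch.utils._pytree import tree_map

import nncf


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()

# NNCF traces a PyTorch model before compressing its weights, so it needs one
# set of example inputs; mirror the diff's tree_map call to put every tensor
# in the input structure on the model's device.
dummy_inputs = {"x": torch.rand(1, 16)}
device = next(model.parameters()).device
dummy_inputs = tree_map(
    lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
)

# The single-sample nncf.Dataset only drives tracing here, not calibration.
compressed = nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs]))
```

Note the design choice upstream of this branch: the task and onnx_config resolution, along with the stateful check, are hoisted above the weights_only split so that both the weight-compression and full-quantization paths share a single stateful decision.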
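And a hypothetical end-to-end invocation of the code path this diff changes, assuming optimum-intel's OVQuantizer API of this era; the model name, save directory, and flag values are illustrative, not taken from the PR.

```python
# Hypothetical usage of the weight-only path; names are illustrative.
from transformers import AutoModelForCausalLM

from optimum.intel import OVQuantizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
quantizer = OVQuantizer.from_pretrained(model, task="text-generation")

# weights_only=True takes the nncf.compress_weights branch shown above;
# stateful export applies only when the runtime and task support it.
quantizer.quantize(save_directory="gpt2-ov-int8", weights_only=True)
```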