import inspect
import logging
import os
+from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import transformers
from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
+from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.torch import create_compressed_model, register_default_init_args, register_module
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch.initialization import PTInitializingDataLoader
@@ -550,7 +552,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):

def _weight_only_quantization(
    model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict]
-):
+) -> openvino.runtime.Model:
    config = quantization_config
    if isinstance(config, dict):
        config = OVWeightQuantizationConfig.from_dict(quantization_config)
@@ -564,7 +566,8 @@ def _weight_only_quantization(

        from optimum.gptq.data import get_dataset, prepare_dataset

-        dataset = get_dataset(config.dataset, tokenizer, seqlen=32)
+        nsamples = config.num_samples if config.num_samples else 128
+        dataset = get_dataset(config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
        dataset = prepare_dataset(dataset)

    sensitivity_metric = None
@@ -590,4 +593,92 @@ def _weight_only_quantization(
        # awq=config.quant_method == "awq", # TODO: remove and add it back once nncf v2.9.0 is released
        ignored_scope=ignored_scope,
        dataset=dataset,
+        # subset_size=config.num_samples if config.num_samples else 128, # TODO: enable from nncf v2.9.0
    )
+
+
+def _get_operation_const_op(operation, const_port_id: int):
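+    # Walk upward from the operation's input at const_port_id, looking through
+    # value-propagating nodes (Convert, FakeQuantize, Reshape), and return the
+    # backing Constant node if one is found, otherwise None.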
+    node = operation.input_value(const_port_id).get_node()
+    queue = deque([node])
+    constant_node = None
+    allowed_propagation_types_list = ["Convert", "FakeQuantize", "Reshape"]
+
+    while len(queue) != 0:
+        curr_node = queue.popleft()
+        if curr_node.get_type_name() == "Constant":
+            constant_node = curr_node
+            break
+        if len(curr_node.inputs()) == 0:
+            break
+        if curr_node.get_type_name() in allowed_propagation_types_list:
+            queue.append(curr_node.input_value(0).get_node())
+
+    return constant_node
+
+
+def _is_embedding(node) -> bool:
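+    # Treat the node as an embedding lookup when its data input (port 0) is a
+    # floating-point tensor backed by a Constant, i.e. a weight table gathered by index.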
+    allowed_types_list = ["f16", "f32", "f64"]
+    const_port_id = 0
+    input_tensor = node.input_value(const_port_id)
+    if input_tensor.get_element_type().get_type_name() in allowed_types_list:
+        const_node = _get_operation_const_op(node, const_port_id)
+        if const_node is not None:
+            return True
+
+    return False
+
+
+def _collect_ops_with_weights(model):
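+    # Collect the friendly names of ops whose weights get compressed in hybrid mode:
+    # MatMuls with a constant input on either port, plus embedding-like Gathers.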
+    ops_with_weights = []
+    for op in model.get_ops():
+        if op.get_type_name() == "MatMul":
+            constant_node_0 = _get_operation_const_op(op, const_port_id=0)
+            constant_node_1 = _get_operation_const_op(op, const_port_id=1)
+            if constant_node_0 or constant_node_1:
+                ops_with_weights.append(op.get_friendly_name())
+        if op.get_type_name() == "Gather" and _is_embedding(op):
+            ops_with_weights.append(op.get_friendly_name())
+
+    return ops_with_weights
+
+
+def _hybrid_quantization(
+    model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: Dict[str, Any]
+) -> openvino.runtime.Model:
+    """
+    Quantize a model in hybrid mode with NNCF: the weights of MatMul and Embedding
+    layers are compressed, while the activations of the remaining layers are quantized.
+    The optimization specifications are defined in `quantization_config`.
+
+    Args:
+        model (`openvino.runtime.Model`):
+            The OpenVINO Runtime model to apply hybrid quantization to.
+        quantization_config (`OVWeightQuantizationConfig`):
+            The configuration containing the parameters related to quantization.
+        dataset (`Dict[str, Any]`):
+            The dataset used for hybrid quantization.
+    Returns:
+        The OpenVINO Runtime model with hybrid quantization applied.
+    """
+    ops_to_compress = _collect_ops_with_weights(model)
+
+    # Ops selected for weight compression are excluded from the activation quantization step.
+    ignored_scope = quantization_config.ignored_scope if isinstance(quantization_config.ignored_scope, dict) else {}
+    ptq_ignored_scope = nncf.IgnoredScope(**ignored_scope)
+    ptq_ignored_scope.names += ops_to_compress
+
+    # Convolutions are excluded from weight compression so they get fully quantized at the PTQ step.
+    wc_quantization_config = copy.deepcopy(quantization_config)
+    wc_quantization_config.ignored_scope = ignored_scope
+    wc_quantization_config.ignored_scope["types"] = ignored_scope.get("types", []) + ["Convolution"]
+    compressed_model = _weight_only_quantization(model, wc_quantization_config)
+
+    subset_size = quantization_config.num_samples if quantization_config.num_samples else 200
+    quantized_model = nncf.quantize(
+        model=compressed_model,
+        calibration_dataset=nncf.Dataset(dataset),
+        model_type=nncf.ModelType.TRANSFORMER,
+        ignored_scope=ptq_ignored_scope,
+        # The SQ algo should be disabled for MatMul nodes because their weights are already compressed
+        advanced_parameters=nncf.AdvancedQuantizationParameters(AdvancedSmoothQuantParameters(matmul=-1)),
+        subset_size=subset_size,
+    )
+    return quantized_model
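
For context, a minimal sketch of how the new `_hybrid_quantization` entry point might be exercised end to end. This is not part of the diff: the IR path, input name and shape, sample count, and the public import locations are assumptions, and `_hybrid_quantization` is a private helper imported here only for illustration.

```python
# Hypothetical usage sketch; paths, shapes, and the helper's import location are assumptions.
import numpy as np
import openvino.runtime as ov

from optimum.intel import OVWeightQuantizationConfig
from optimum.intel.openvino.quantization import _hybrid_quantization  # private helper

model = ov.Core().read_model("unet/openvino_model.xml")  # hypothetical IR path

# When num_samples is set it also becomes the PTQ subset_size; otherwise 200 is used.
config = OVWeightQuantizationConfig(bits=4, num_samples=224)

# Calibration samples: dicts mapping input names to arrays; the helper wraps this
# iterable in nncf.Dataset before calling nncf.quantize.
calibration_data = [{"sample": np.zeros((1, 4, 64, 64), dtype=np.float32)}]

quantized_model = _hybrid_quantization(model, config, calibration_data)
ov.serialize(quantized_model, "unet/openvino_model_quantized.xml")
```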