31
31
from nncf .common .utils .helpers import create_table
32
32
from nncf .experimental .common .tensor_statistics .statistics import WCTensorStatistic
33
33
from nncf .parameters import BackupMode
34
+ from nncf .parameters import CompressionFormat
34
35
from nncf .parameters import CompressWeightsMode
35
36
from nncf .parameters import SensitivityMetric
36
37
from nncf .quantization .advanced_parameters import AdvancedCompressionParameters
45
46
from nncf .quantization .algorithms .weight_compression .weight_lowering import WeightCompressionConfig
46
47
from nncf .scopes import IgnoredScope
47
48
from nncf .scopes import get_ignored_node_names_from_ignored_scope
49
+ from nncf .tensor .definitions import TensorDataType
48
50
49
51
# Type variables for backend-agnostic model and tensor objects; the concrete
# types are supplied by the backend entity (e.g. OpenVINO / Torch backends).
TModel = TypeVar("TModel")
TTensor = TypeVar("TTensor")
56
58
CompressWeightsMode .NF4 ,
57
59
CompressWeightsMode .E2M1 ,
58
60
]
61
# Floating-point weight dtypes eligible for compression. Weights with any
# other dtype (e.g. already-quantized integer weights) are skipped by the
# algorithm — see the `weight_dtype in SUPPORTED_DATA_TYPES` checks in
# `WeightCompression`.
SUPPORTED_DATA_TYPES = [
    TensorDataType.float16,
    TensorDataType.bfloat16,
    TensorDataType.float32,
    TensorDataType.float64,
]
59
67
60
68
61
69
def get_weight_compression_configuration (
@@ -122,6 +130,7 @@ def check_user_compression_configuration(
122
130
ignored_scope : Optional [IgnoredScope ],
123
131
sensitivity_metric : Optional [SensitivityMetric ],
124
132
backup_mode : Optional [BackupMode ],
133
+ compression_format : Optional [CompressionFormat ],
125
134
advanced_parameters : Optional [AdvancedCompressionParameters ],
126
135
) -> None :
127
136
"""
@@ -172,6 +181,10 @@ def check_user_compression_configuration(
172
181
requires a dataset, but it's not provided."
173
182
raise nncf .ValidationError (msg )
174
183
184
+ if lora_correction and compression_format in [CompressionFormat .FQ , CompressionFormat .FQ_LORA ]:
185
+ msg = "LoRA Correction algorithm is not compatible with FQ and FQ_LORA compression formats."
186
+ raise nncf .ValidationError (msg )
187
+
175
188
176
189
class WeightCompression (Algorithm ):
177
190
"""
@@ -195,6 +208,7 @@ def __init__(
195
208
gptq : bool ,
196
209
lora_correction : bool ,
197
210
backup_mode : BackupMode = BackupMode .INT8_ASYM ,
211
+ compression_format : CompressionFormat = CompressionFormat .DQ ,
198
212
advanced_parameters : Optional [AdvancedCompressionParameters ] = None ,
199
213
):
200
214
"""
@@ -233,6 +247,7 @@ def __init__(
233
247
In this mode, weights are retained in their original precision without any quantization.
234
248
INT8_SYM stands for 8-bit integer symmetric quantization without zero point.
235
249
INT8_ASYM stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point.
250
+ :param compression_format: Describes the format in which the model is saved after weight compression.
236
251
:param advanced_parameters: advanced parameters for algorithms in compression pipeline.
237
252
"""
238
253
super ().__init__ ()
@@ -251,6 +266,7 @@ def __init__(
251
266
self ._gptq = gptq
252
267
self ._lora_correction = lora_correction
253
268
self ._backup_mode = backup_mode
269
+ self ._compression_format = compression_format
254
270
self ._advanced_parameters = (
255
271
advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters ()
256
272
)
@@ -489,7 +505,7 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph)
489
505
continue
490
506
for _ , weight_port_id in self ._backend_entity .get_weight_names_and_port_ids (node , graph ):
491
507
weight_dtype = self ._backend_entity .get_weight_dtype (node , weight_port_id , model , graph )
492
- if weight_dtype . is_float () :
508
+ if weight_dtype in SUPPORTED_DATA_TYPES :
493
509
continue
494
510
weight_shape = self ._backend_entity .get_weight_shape (node , weight_port_id , graph )
495
511
weight_size = reduce (operator .mul , weight_shape , 1 )
@@ -535,7 +551,7 @@ def apply(
535
551
continue
536
552
537
553
weight_dtype = self ._backend_entity .get_weight_dtype (node , weight_port_id , model , graph )
538
- if not weight_dtype . is_float () :
554
+ if weight_dtype not in SUPPORTED_DATA_TYPES :
539
555
continue
540
556
weight_shape = self ._backend_entity .get_weight_shape (node , weight_port_id , graph )
541
557
weight_size = reduce (operator .mul , weight_shape , 1 )
@@ -646,6 +662,7 @@ def apply(
646
662
scales ,
647
663
zero_points ,
648
664
lora_correction_algo ,
665
+ self ._compression_format ,
649
666
)
650
667
651
668
self ._backend_entity .dump_parameters (
@@ -662,6 +679,7 @@ def apply(
662
679
"gptq" : self ._gptq ,
663
680
"lora_correction" : self ._lora_correction ,
664
681
"backup_mode" : self ._backup_mode .value ,
682
+ "compression_format" : self ._compression_format .value ,
665
683
"advanced_parameters" : convert_to_dict_recursively (self ._advanced_parameters ),
666
684
},
667
685
algo_name = "weight_compression" ,
0 commit comments