@@ -78,7 +78,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
78
78
optional_group .add_argument (
79
79
"--quant-mode" ,
80
80
type = str ,
81
- choices = ["int8" , "f8e4m3" , "f8e5m2" ],
81
+ choices = ["int8" , "f8e4m3" , "f8e5m2" , "nf4_f8e4m3" ],
82
82
default = None ,
83
83
help = (
84
84
"Quantization precision mode. This is used for applying full model quantization including activations. "
@@ -307,7 +307,14 @@ def parse_args(parser: "ArgumentParser"):
307
307
def run (self ):
308
308
from ...exporters .openvino .__main__ import infer_task , main_export , maybe_convert_tokenizers
309
309
from ...exporters .openvino .utils import save_preprocessors
310
- from ...intel .openvino .configuration import _DEFAULT_4BIT_CONFIG , OVConfig , get_default_int4_config
310
+ from ...intel .openvino .configuration import (
311
+ _DEFAULT_4BIT_CONFIG ,
312
+ OVCompressWeightsOptions ,
313
+ OVConfig ,
314
+ OVGeneralQuantizationConfig ,
315
+ OVQuantizeOptions ,
316
+ get_default_int4_config ,
317
+ )
311
318
312
319
if self .args .library is None :
313
320
# TODO: add revision, subfolder and token to args
@@ -342,43 +349,39 @@ def run(self):
342
349
if no_compression_parameter_provided (self .args ) and self .args .weight_format == "int4" :
343
350
quantization_config = get_default_int4_config (self .args .model )
344
351
else :
345
- is_int8 = self .args .weight_format == "int8"
346
- quantization_config = {
347
- "bits" : 8 if is_int8 else 4 ,
348
- "ratio" : 1 if is_int8 else (self .args .ratio or _DEFAULT_4BIT_CONFIG ["ratio" ]),
349
- "sym" : self .args .sym or False ,
350
- "group_size" : - 1 if is_int8 else self .args .group_size ,
351
- "all_layers" : None if is_int8 else self .args .all_layers ,
352
- "dataset" : self .args .dataset ,
353
- "num_samples" : self .args .num_samples ,
354
- "quant_method" : "awq" if self .args .awq else "default" ,
355
- "sensitivity_metric" : self .args .sensitivity_metric ,
356
- "scale_estimation" : self .args .scale_estimation ,
357
- "gptq" : self .args .gptq ,
358
- "lora_correction" : self .args .lora_correction ,
359
- "weight_format" : self .args .weight_format ,
360
- "backup_precision" : self .args .backup_precision ,
361
- }
352
+ quantization_config = prepare_for_wc_config (self .args , _DEFAULT_4BIT_CONFIG )
362
353
363
354
if quantization_config .get ("dataset" , None ) is not None :
364
355
quantization_config ["trust_remote_code" ] = self .args .trust_remote_code
365
356
ov_config = OVConfig (quantization_config = quantization_config )
366
- else :
357
+ elif self . args . quant_mode is not None :
367
358
if self .args .dataset is None :
368
359
raise ValueError (
369
360
"Dataset is required for full quantization. Please provide it with --dataset argument."
370
361
)
371
362
372
- quantization_config = {
373
- "weight_format" : self .args .quant_mode ,
374
- "activation_format" : self .args .quant_mode ,
375
- "bits" : 8 ,
376
- "sym" : self .args .sym or False ,
377
- "dataset" : self .args .dataset ,
378
- "num_samples" : self .args .num_samples ,
379
- "smooth_quant_alpha" : self .args .smooth_quant_alpha ,
380
- "trust_remote_code" : self .args .trust_remote_code ,
381
- }
363
+ if self .args .quant_mode == "nf4_f8e4m3" :
364
+ wc_config = prepare_for_wc_config (self .args , _DEFAULT_4BIT_CONFIG )
365
+ wc_config ["weight_format" ] = "nf4"
366
+ cw_options = OVCompressWeightsOptions .init_with_format (** wc_config )
367
+
368
+ q_config = prepare_for_q_config (self .args )
369
+ q_config ["activation_format" ] = "f8e4m3"
370
+ q_options = OVQuantizeOptions .init_with_format (** q_config )
371
+
372
+ quantization_config = OVGeneralQuantizationConfig .init_with_format (
373
+ bits = 8 ,
374
+ sym = self .args .sym ,
375
+ ignored_scope = None ,
376
+ num_samples = self .args .num_samples ,
377
+ dataset = self .args .dataset ,
378
+ trust_remote_code = self .args .trust_remote_code ,
379
+ weight_format = self .args .weight_format ,
380
+ )
381
+ quantization_config .compress_weights_options = cw_options
382
+ quantization_config .quantize_options = q_options
383
+ else :
384
+ quantization_config = prepare_for_q_config (self .args )
382
385
ov_config = OVConfig (quantization_config = quantization_config )
383
386
384
387
quantization_config = ov_config .quantization_config if ov_config else None
@@ -470,3 +473,36 @@ def run(self):
470
473
library_name = library_name ,
471
474
# **input_shapes,
472
475
)
476
+
477
+
478
def prepare_for_wc_config(args, default_configs) -> dict:
    """Build the weight-compression settings dictionary from parsed CLI arguments.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``weight_format``, ``ratio``, ``sym``, ``group_size``,
            ``all_layers``, ``dataset``, ``num_samples``, ``awq``,
            ``sensitivity_metric``, ``scale_estimation``, ``gptq``,
            ``lora_correction`` and ``backup_precision``.
        default_configs: Mapping that supplies the fallback ``"ratio"`` used
            when ``args.ratio`` was not provided.

    Returns:
        A dict of weight-compression options. For ``weight_format == "int8"``
        the result pins ``bits=8``, ``ratio=1``, ``group_size=-1`` and
        ``all_layers=None``; for every other format it uses ``bits=4`` and
        takes those three settings from ``args``, falling back to
        ``default_configs["ratio"]`` for an unset ratio.
    """
    is_int8 = args.weight_format == "int8"

    if is_int8:
        ratio = 1
    elif args.ratio is None:
        # Fall back to the default only when the option was genuinely omitted.
        # A plain ``args.ratio or default`` would also discard an explicit 0.0.
        ratio = default_configs["ratio"]
    else:
        ratio = args.ratio

    return {
        "bits": 8 if is_int8 else 4,
        "ratio": ratio,
        "sym": args.sym or False,  # normalise an unset (None) flag to False
        "group_size": -1 if is_int8 else args.group_size,
        "all_layers": None if is_int8 else args.all_layers,
        "dataset": args.dataset,
        "num_samples": args.num_samples,
        "quant_method": "awq" if args.awq else "default",
        "sensitivity_metric": args.sensitivity_metric,
        "scale_estimation": args.scale_estimation,
        "gptq": args.gptq,
        "lora_correction": args.lora_correction,
        "weight_format": args.weight_format,
        "backup_precision": args.backup_precision,
    }
496
+
497
+
498
def prepare_for_q_config(args):
    """Assemble the full-quantization settings dictionary from parsed CLI arguments.

    ``args.quant_mode`` is used for both the weight and the activation format,
    the bit width is fixed at 8, and an unset ``args.sym`` becomes ``False``.
    The remaining entries (``dataset``, ``num_samples``, ``smooth_quant_alpha``
    and ``trust_remote_code``) are copied from ``args`` unchanged.
    """
    mode = args.quant_mode

    symmetric = args.sym
    if not symmetric:
        symmetric = False

    settings = {}
    settings["weight_format"] = mode
    settings["activation_format"] = mode
    settings["bits"] = 8
    settings["sym"] = symmetric
    settings["dataset"] = args.dataset
    settings["num_samples"] = args.num_samples
    settings["smooth_quant_alpha"] = args.smooth_quant_alpha
    settings["trust_remote_code"] = args.trust_remote_code
    return settings
0 commit comments