diff --git a/docs/source/openvino/export.mdx b/docs/source/openvino/export.mdx index 1d0c534193..441614402e 100644 --- a/docs/source/openvino/export.mdx +++ b/docs/source/openvino/export.mdx @@ -31,7 +31,8 @@ Check out the help for more options: ```text usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code] - [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8,f8e4m3,f8e5m2}] + [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] + [--quant-mode {int8,f8e4m3,f8e5m2,nf4_f8e4m3,nf4_f8e5m2,int4_f8e4m3,int4_f8e5m2}] [--library {transformers,diffusers,timm,sentence_transformers,open_clip}] [--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym] [--group-size GROUP_SIZE] [--backup-precision {none,int8_sym,int8_asym}] @@ -67,7 +68,7 @@ Optional arguments: on your local machine arbitrary code present in the model repository. --weight-format {fp32,fp16,int8,int4,mxfp4,nf4} The weight format of the exported model. - --quant-mode {int8,f8e4m3,f8e5m2} + --quant-mode {int8,f8e4m3,f8e5m2,nf4_f8e4m3,nf4_f8e5m2,int4_f8e4m3,int4_f8e5m2} Quantization precision mode. This is used for applying full model quantization including activations. --library {transformers,diffusers,timm,sentence_transformers,open_clip} diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 6f3174820a..443a8996ed 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -78,7 +78,7 @@ def parse_args_openvino(parser: "ArgumentParser"): optional_group.add_argument( "--quant-mode", type=str, - choices=["int8", "f8e4m3", "f8e5m2"], + choices=["int8", "f8e4m3", "f8e5m2", "nf4_f8e4m3", "nf4_f8e5m2", "int4_f8e4m3", "int4_f8e5m2"], default=None, help=( "Quantization precision mode. This is used for applying full model quantization including activations. " @@ -352,23 +352,7 @@ def run(self): if no_compression_parameter_provided(self.args) and self.args.weight_format == "int4": quantization_config = get_default_int4_config(self.args.model) else: - is_int8 = self.args.weight_format == "int8" - quantization_config = { - "bits": 8 if is_int8 else 4, - "ratio": 1.0 if is_int8 else (self.args.ratio or _DEFAULT_4BIT_CONFIG["ratio"]), - "sym": self.args.sym or False, - "group_size": -1 if is_int8 else self.args.group_size, - "all_layers": None if is_int8 else self.args.all_layers, - "dataset": self.args.dataset, - "num_samples": self.args.num_samples, - "quant_method": "awq" if self.args.awq else "default", - "sensitivity_metric": self.args.sensitivity_metric, - "scale_estimation": self.args.scale_estimation, - "gptq": self.args.gptq, - "lora_correction": self.args.lora_correction, - "weight_format": self.args.weight_format, - "backup_precision": self.args.backup_precision, - } + quantization_config = prepare_wc_config(self.args, _DEFAULT_4BIT_CONFIG) if quantization_config.get("dataset", None) is not None: quantization_config["trust_remote_code"] = self.args.trust_remote_code @@ -378,16 +362,24 @@ def run(self): raise ValueError( "Dataset is required for full quantization. Please provide it with --dataset argument." 
) - quantization_config = { - "weight_format": self.args.quant_mode, - "activation_format": self.args.quant_mode, - "bits": 8, - "sym": self.args.sym or False, - "dataset": self.args.dataset, - "num_samples": self.args.num_samples, - "smooth_quant_alpha": self.args.smooth_quant_alpha, - "trust_remote_code": self.args.trust_remote_code, - } + + if self.args.quant_mode in ["nf4_f8e4m3", "nf4_f8e5m2", "int4_f8e4m3", "int4_f8e5m2"]: + wc_config = prepare_wc_config(self.args, _DEFAULT_4BIT_CONFIG) + wc_dtype, q_dtype = self.args.quant_mode.split("_") + wc_config["dtype"] = wc_dtype + + q_config = prepare_q_config(self.args) + q_config["dtype"] = q_dtype + + quantization_config = { + "weight_quantization_config": wc_config, + "full_quantization_config": q_config, + "num_samples": self.args.num_samples, + "dataset": self.args.dataset, + "trust_remote_code": self.args.trust_remote_code, + } + else: + quantization_config = prepare_q_config(self.args) ov_config = OVConfig(quantization_config=quantization_config) quantization_config = ov_config.quantization_config if ov_config else None @@ -486,3 +478,35 @@ def run(self): variant=self.args.variant, # **input_shapes, ) + + +def prepare_wc_config(args, default_configs): + is_int8 = args.weight_format == "int8" + return { + "bits": 8 if is_int8 else 4, + "ratio": 1.0 if is_int8 else (args.ratio or default_configs["ratio"]), + "sym": args.sym or False, + "group_size": -1 if is_int8 else args.group_size, + "all_layers": None if is_int8 else args.all_layers, + "dataset": args.dataset, + "num_samples": args.num_samples, + "quant_method": "awq" if args.awq else "default", + "sensitivity_metric": args.sensitivity_metric, + "scale_estimation": args.scale_estimation, + "gptq": args.gptq, + "lora_correction": args.lora_correction, + "dtype": args.weight_format, + "backup_precision": args.backup_precision, + } + + +def prepare_q_config(args): + return { + "dtype": args.quant_mode, + "bits": 8, + "sym": args.sym or False, + "dataset": args.dataset, + "num_samples": args.num_samples, + "smooth_quant_alpha": args.smooth_quant_alpha, + "trust_remote_code": args.trust_remote_code, + } diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index b1651db078..8e0f12b747 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -84,6 +84,7 @@ "OVQuantizationConfig", "OVWeightQuantizationConfig", "OVDynamicQuantizationConfig", + "OVMixedQuantizationConfig", ] ) else: @@ -94,6 +95,7 @@ "OVQuantizationConfig", "OVWeightQuantizationConfig", "OVDynamicQuantizationConfig", + "OVMixedQuantizationConfig", ] ) @@ -272,6 +274,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_openvino_and_nncf_objects import ( OVDynamicQuantizationConfig, + OVMixedQuantizationConfig, OVQuantizationConfig, OVQuantizer, OVTrainingArguments, @@ -280,6 +283,7 @@ else: from .openvino import ( OVDynamicQuantizationConfig, + OVMixedQuantizationConfig, OVQuantizationConfig, OVQuantizer, OVTrainingArguments, diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index d3142ad802..71eeb11f56 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -55,7 +55,13 @@ from .trainer import OVTrainer -from .configuration import OVConfig, OVDynamicQuantizationConfig, OVQuantizationConfig, OVWeightQuantizationConfig +from .configuration import ( + OVConfig, + OVDynamicQuantizationConfig, + OVMixedQuantizationConfig, + OVQuantizationConfig, + OVWeightQuantizationConfig, +) from .modeling import ( 
OVModelForAudioClassification, OVModelForAudioFrameClassification, diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index 966ab57c51..ad6b14aa2a 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union import torch from transformers.utils.quantization_config import QuantizationConfigMixin @@ -263,25 +263,18 @@ class OVQuantizationConfigBase(QuantizationConfigMixin): def __init__( self, - bits: int = 8, - sym: bool = False, - ignored_scope: Optional[dict] = None, + ignored_scope: Optional[Union[dict, "nncf.IgnoredScope"]] = None, num_samples: Optional[int] = None, - dataset: Optional[Optional[Union[str, List[str]]]] = None, + dataset: Optional[Union[str, List[str]]] = None, tokenizer: Optional[str] = None, processor: Optional[str] = None, trust_remote_code: bool = False, - weight_format: Optional[str] = None, **kwargs, ): """ Args: - bits (`int`, defaults to 8): - The number of bits to quantize to. - sym (`bool`, defaults to `False`): - Whether to use symmetric quantization. - ignored_scope (`dict`, *optional*): - An ignored scope that defines a list of model nodes to be ignored during quantization. Dictionary + ignored_scope (`dict` or `nncf.IgnoredScope`, *optional*): + An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class. num_samples (`int`, *optional*): The maximum number of samples composing the calibration dataset. @@ -295,18 +288,12 @@ def __init__( Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository. - weight_format (`str`, *optional*): - Data format weights are compressed to. """ - self.bits = bits - self.sym = sym self.num_samples = num_samples self.dataset = dataset self.tokenizer = tokenizer self.processor = processor self.trust_remote_code = trust_remote_code - self.weight_format = weight_format - if isinstance(ignored_scope, nncf.IgnoredScope): ignored_scope = ignored_scope.__dict__ self.ignored_scope = ignored_scope @@ -326,6 +313,9 @@ def get_ignored_scope_instance(self) -> "nncf.IgnoredScope": return nncf.IgnoredScope() return nncf.IgnoredScope(**copy.deepcopy(self.ignored_scope)) + def clone(self): + return copy.deepcopy(self) + @dataclass class OVWeightQuantizationConfig(OVQuantizationConfigBase): @@ -370,7 +360,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase): sensitivity_metric (`str`, *optional*): The sensitivity metric for assigning quantization precision to layers. In order to preserve the accuracy of the model, the more sensitive layers receives a higher precision. - ignored_scope (`dict`, *optional*): + ignored_scope (`dict` or `nncf.IgnoredScope`, *optional*): An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class. 
num_samples (`int`, *optional*): @@ -389,8 +379,8 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase): scale_estimation (`bool`, *optional*): Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and compressed layers. Providing a dataset is required to run scale estimation. - weight_format (`str`, *optional*): - Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4', 'nf4']. + dtype (`str`, *optional*): + Data type weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4', 'nf4']. qptq (`bool`, *optional*): Whether to apply GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the difference between activations of a compressed and original layer. Dataset is required to run GPTQ. @@ -424,11 +414,11 @@ def __init__( ratio: float = 1.0, all_layers: Optional[bool] = None, sensitivity_metric: Optional[str] = None, - ignored_scope: Optional[dict] = None, + ignored_scope: Optional[Union[dict, "nncf.IgnoredScope"]] = None, num_samples: Optional[int] = None, quant_method: Union[str, OVQuantizationMethod] = OVQuantizationMethod.DEFAULT, scale_estimation: bool = None, - weight_format: Optional[str] = None, + dtype: Optional[str] = None, gptq: bool = None, processor: Optional[str] = None, lora_correction: bool = None, @@ -436,16 +426,15 @@ def __init__( **kwargs, ): super().__init__( - bits=bits, - sym=sym, ignored_scope=ignored_scope, num_samples=num_samples, dataset=dataset, tokenizer=tokenizer, processor=processor, trust_remote_code=trust_remote_code, - weight_format=weight_format, ) + self.bits = bits + self.sym = sym self.group_size = group_size or (-1 if bits == 8 else 128) self.ratio = ratio self.all_layers = all_layers @@ -455,6 +444,13 @@ def __init__( self.gptq = gptq self.lora_correction = lora_correction self.backup_precision = backup_precision + if kwargs.get("weight_format") is not None: + logger.warning( + "The `weight_format` parameter is deprecated and will be removed in optimum-intel v1.24.0. " + "Please use `dtype` instead." + ) + dtype = kwargs.get("weight_format") + self.dtype = dtype self.post_init() def post_init(self): @@ -493,10 +489,18 @@ def post_init(self): "quantization algorithm is selected and compression ratio is 1.0." ) + if self.dtype in ["int4", "int8"]: + bits = 4 if self.dtype == "int4" else 8 + if self.bits is not None and self.bits != bits: + logger.warning( + f"Overriding `bits` parameter to the value `bits`={bits} to match the given {self.dtype} `dtype`." + ) + self.bits = bits + if self.bits not in [4, 8]: raise ValueError(f"Only support quantization to [4,8] bits but found {self.bits}") - if self.bits == 8: + if self.bits == 8 and self.dtype: if self.ratio != 1: raise ValueError( f"For 8-bit quantization, `ratio` is expected to be set to 1.0, but was set to {self.ratio}" @@ -542,29 +546,61 @@ def post_init(self): if self.processor is not None and not isinstance(self.processor, str): raise ValueError(f"Processor is expected to be a string, but found {self.processor}") - if self.weight_format is None: - self.weight_format = "int4" if self.bits == 4 else "int8" - if self.weight_format not in ["int4", "int8", "mxfp4", "nf4"]: + if self.dtype is None: + self.dtype = "int4" if self.bits == 4 else "int8" + if self.dtype not in ["int4", "int8", "mxfp4", "nf4"]: raise ValueError( - f"Weight format must be one of the following: ['int4', 'int8', 'mxfp4', 'nf4'], but found: {self.weight_format}." 
+ f"Weights quantization data type must be one of the following: ['int4', 'int8', 'mxfp4', 'nf4'], but found: {self.dtype}." ) - if self.weight_format in ["mxfp4", "nf4"]: + if self.dtype in ["mxfp4", "nf4"]: if self.bits != 4: raise ValueError( - f"When applying weight compression with '{self.weight_format}' weight format, the `bits` parameter must be set to 4, but found {self.bits}" + f"When applying weight compression with '{self.dtype}' data type, the `bits` parameter must be set to 4, but found {self.bits}" ) - if self.weight_format == "mxfp4": + if self.dtype == "mxfp4": if self.quant_method == OVQuantizationMethod.AWQ: - raise ValueError("The AWQ algorithm is not supported for 'mxpf4' weight format") + raise ValueError("The AWQ algorithm is not supported for 'mxpf4' data type") if self.scale_estimation: - raise ValueError("The Scale Estimation algorithm is not supported for 'mxpf4' weight format") + raise ValueError("The Scale Estimation algorithm is not supported for 'mxpf4' data type") if self.gptq: - raise ValueError("The GPTQ algorithm is not supported for 'mxfp4' weight format") + raise ValueError("The GPTQ algorithm is not supported for 'mxfp4' data type") if self.lora_correction: - raise ValueError("The LoRA Correction algorithm is not supported for 'mxfp4' weight format") + raise ValueError("The LoRA Correction algorithm is not supported for 'mxfp4' data type") if self.gptq and self.lora_correction: raise ValueError("The GPTQ and LoRA Correction algorithms can't be applied simultaneously") + def to_nncf_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary with the variables that are ready to use for nncf.quantize() call. + """ + + signed_bitness = {4: "int4", 8: "int8"} + mode = self.dtype if self.dtype else signed_bitness[self.bits] + if mode in signed_bitness.values(): + mode += "_sym" if self.sym else "_asym" + if mode == "mxfp4": + mode = "e2m1" + mode = nncf.CompressWeightsMode(mode) + + awq = True if self.quant_method == OVQuantizationMethod.AWQ else None + sensitivity_metric = nncf.SensitivityMetric(self.sensitivity_metric) if self.sensitivity_metric else None + backup_mode = nncf.BackupMode(self.backup_precision) if self.backup_precision else None + result = { + "mode": mode, + "ratio": self.ratio, + "group_size": self.group_size, + "ignored_scope": self.get_ignored_scope_instance(), + "all_layers": self.all_layers, + "sensitivity_metric": sensitivity_metric, + "subset_size": self.num_samples or 128, + "awq": awq, + "scale_estimation": self.scale_estimation, + "gptq": self.gptq, + "lora_correction": self.lora_correction, + "backup_mode": backup_mode, + } + return result + @dataclass class OVDynamicQuantizationConfig(OVWeightQuantizationConfig): @@ -593,8 +629,8 @@ def __init__( self, bits: int = 8, sym: bool = False, - ignored_scope: Optional[dict] = None, - num_samples: Optional[int] = 300, + ignored_scope: Optional[Union[dict, "nncf.IgnoredScope"]] = None, + num_samples: Optional[int] = 128, model_type: str = "transformer", fast_bias_correction: bool = True, overflow_fix: str = "disable", @@ -603,8 +639,7 @@ def __init__( processor: Optional[str] = None, trust_remote_code: bool = False, smooth_quant_alpha: Optional[float] = None, - weight_format: Optional[str] = "int8", - activation_format: Optional[str] = "int8", + dtype: Optional[str] = "int8", **kwargs, ): """ @@ -616,7 +651,7 @@ def __init__( The number of bits to quantize to. sym (`bool`, defaults to `False`): Whether to use symmetric quantization on the activations. 
Symmetric quantization will be applied on the weights in any case. - ignored_scope (`dict`, *optional*): + ignored_scope (`dict` or `nncf.IgnoredScope`, *optional*): An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class. num_samples (`int`, *optional*): @@ -649,33 +684,33 @@ def __init__( smooth_quant_alpha (`float`, *optional*): SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and reduces quantization error. - weight_format (`str`, defaults to "int8"): - Data format weights are quantized to. Possible values: ['int8', 'f8e4m3', 'f8e5m2']. - activation_format (`str`, defaults to "int8"): - Data format activations are compressed to. Possible values: ['int8', 'f8e4m3', 'f8e5m2']. + dtype (`str`, defaults to "int8"): + Data type activations are compressed to. Possible values: ['int8', 'f8e4m3', 'f8e5m2']. """ super().__init__( - bits=bits, - sym=sym, ignored_scope=ignored_scope, num_samples=num_samples, dataset=dataset, tokenizer=tokenizer, processor=processor, trust_remote_code=trust_remote_code, - weight_format=weight_format, ) + self.bits = bits + self.sym = sym self.model_type = model_type self.fast_bias_correction = fast_bias_correction self.overflow_fix = overflow_fix self.smooth_quant_alpha = smooth_quant_alpha - self.activation_format = activation_format - - f8_formats = ["f8e4m3", "f8e5m2"] - if self.activation_format in f8_formats and self.weight_format in f8_formats: - logger.info( - f"{self.activation_format} for activations and {self.weight_format} weights were found. A symmetrical scheme will be used." + if kwargs.get("activation_format") is not None: + logger.warning( + "The `activation_format` parameter is deprecated and will be removed in optimum-intel v1.24.0. " + "Please use `dtype` instead." ) + dtype = kwargs.get("activation_format") + self.dtype = dtype + + f8_dtypes = ["f8e4m3", "f8e5m2"] + if self.dtype in f8_dtypes: self.sym = True self.post_init() @@ -696,11 +731,46 @@ def post_init(self): if self.bits != 8: raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}") - if self.smooth_quant_alpha is not None and not (0 <= self.smooth_quant_alpha <= 1): + if self.smooth_quant_alpha is not None and ( + self.smooth_quant_alpha != -1 and not (0 <= self.smooth_quant_alpha <= 1) + ): raise ValueError( - f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}" + f"SmoothQuant alpha parameter can equal -1 or be in range [0, 1], but found {self.smooth_quant_alpha}" + ) + + def to_nncf_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary with the variables that are ready to use for nncf.compress_weights() call. 
+ """ + + preset = "performance" if self.sym else "mixed" + advanced_parameters_dict = {"overflow_fix": self.overflow_fix} + if self.smooth_quant_alpha: + advanced_parameters_dict["smooth_quant_alphas"] = {"matmul": self.smooth_quant_alpha} + + mode_map = {"f8e4m3": "fp8_e4m3", "f8e5m2": "fp8_e5m2"} + mode = mode_map.get(self.dtype) + + preset = nncf.QuantizationPreset(preset) + model_type = nncf.ModelType(self.model_type) + advanced_parameters = nncf.AdvancedQuantizationParameters( + overflow_fix=advanced_parameters_dict["overflow_fix"], + ) + if "smooth_quant_alphas" in advanced_parameters_dict: + advanced_parameters.smooth_quant_alphas = nncf.AdvancedSmoothQuantParameters( + **advanced_parameters_dict["smooth_quant_alphas"] ) + return { + "mode": mode, + "preset": preset, + "subset_size": self.num_samples or 128, + "fast_bias_correction": self.fast_bias_correction, + "model_type": model_type, + "ignored_scope": self.get_ignored_scope_instance(), + "advanced_parameters": advanced_parameters, + } + class OVConfig(BaseConfig): CONFIG_NAME = "openvino_config.json" @@ -719,13 +789,20 @@ def __init__( self.save_onnx_model = save_onnx_model self.optimum_version = kwargs.pop("optimum_version", None) if isinstance(quantization_config, dict): - quantization_config = self._quantization_config_from_dict(quantization_config) + quantization_config = self.quantization_config_from_dict(quantization_config) self.quantization_config = quantization_config self.compression = kwargs.get( "compression", None ) # A field for backward-compatability of training-time compression parameters if self.quantization_config is not None: - self.dtype = self.quantization_config.weight_format + if isinstance(self.quantization_config, (OVWeightQuantizationConfig, OVQuantizationConfig)): + self.dtype = self.quantization_config.dtype + elif isinstance(self.quantization_config, OVMixedQuantizationConfig): + wc_dtype = self.quantization_config.weight_quantization_config.dtype + q_dtype = self.quantization_config.full_quantization_config.dtype + self.dtype = f"{wc_dtype}_{q_dtype}" + else: + raise ValueError(f"Unsupported type of quantization config: {type(self.quantization_config)}") else: self.dtype = dtype @@ -740,7 +817,9 @@ def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False): ] @staticmethod - def _quantization_config_from_dict(quantization_config: dict) -> OVQuantizationConfigBase: + def quantization_config_from_dict(quantization_config: dict) -> OVQuantizationConfigBase: + if "weight_quantization_config" in quantization_config and "full_quantization_config" in quantization_config: + return OVMixedQuantizationConfig.from_dict(quantization_config) wq_args = inspect.getfullargspec(OVWeightQuantizationConfig.__init__).args q_args = inspect.getfullargspec(OVQuantizationConfig.__init__).args weight_only = quantization_config.pop("weight_only", None) @@ -782,3 +861,108 @@ def to_dict(self) -> Dict[str, Any]: def to_diff_dict(self) -> Dict[str, Any]: return self._to_dict_safe(to_diff_dict=True) + + +class OVMixedQuantizationConfig(OVQuantizationConfigBase): + def __init__( + self, + weight_quantization_config: Union[OVWeightQuantizationConfig, dict], + full_quantization_config: Union[OVQuantizationConfig, dict], + ignored_scope: Optional[Union[dict, "nncf.IgnoredScope"]] = None, + num_samples: Optional[int] = None, + dataset: Optional[Union[str, List[str]]] = None, + tokenizer: Optional[str] = None, + processor: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ): + """ + 
Configuration class for mixed quantization where we separately quantize: + (1) weights of weighted layers to the precision given in the `weight_quantization_config`, and + (2) weights and activations of other possible layers; precision is given in the `full_quantization_config`. + + By default, weights of all weighted layers are quantized in the first step. In the second step activations of + weighted and non-weighted layers are quantized. If some layers are instructed to be ignored in the first step + with `weight_quantization_config.ignored_scope` parameter, both weights and activations of these layers are + quantized to the precision given in the `full_quantization_config`. + + Args: + weight_quantization_config (`OVWeightQuantizationConfig` or `dict`): + Configuration related to weight quantization. + full_quantization_config (`OVQuantizationConfig` or `dict`): + Configuration related to full quantization. + ignored_scope (`dict` or `nncf.IgnoredScope`, *optional*): + An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary + entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class. + Ignored scope provided here will be used for both weight and full quantization steps. + num_samples (`int`, *optional*): + The maximum number of samples composing the calibration dataset. + dataset (`str or List[str]`, *optional*): + The dataset used for data-aware optimization with NNCF. + tokenizer (`str`, *optional*): + The tokenizer used to process the dataset. + processor (`str`, *optional*): + A transformers processor used to process the dataset inputs. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be + set for repositories you trust and in which you have read the code, as it will execute on your local + machine arbitrary code present in the model repository. + **kwargs: + """ + self.weight_quantization_config = self._initialize_quantization_config( + weight_quantization_config, OVWeightQuantizationConfig + ) + wqc = self.weight_quantization_config + + self.full_quantization_config = self._initialize_quantization_config( + full_quantization_config, OVQuantizationConfig + ) + fqc = self.full_quantization_config + + if fqc.dtype in ["f8e4m3", "f8e5m2"] and wqc.backup_precision is None: + # Here we simulate FP8 backup weight compression precision through full quantization: during weight + # compression step some weighted layers are kept in original precision and later are compressed to FP8 + # during full precision quantization step. + # The issue with current approach is that if one provides an ignored scope for the full quantization step, + # then the weights of the layers under this ignored scope won't be compressed to FP8. + # TODO: remove once there is support for FP8 weight compression in NNCF + wqc.backup_precision = "none" + + # Pull dataset-related parameters from child configs. This is not the intended use case, but we process it just + # in case user sets those parameters inside child configs only. 
+ num_samples = max((num_samples or 0, wqc.num_samples or 0, fqc.num_samples or 0)) or None + dataset = dataset or wqc.dataset or fqc.dataset + tokenizer = tokenizer or wqc.tokenizer or fqc.tokenizer + processor = processor or wqc.processor or fqc.processor + trust_remote_code = trust_remote_code or wqc.trust_remote_code or fqc.trust_remote_code + super().__init__( + ignored_scope=ignored_scope, + num_samples=num_samples, + dataset=dataset, + tokenizer=tokenizer, + processor=processor, + trust_remote_code=trust_remote_code, + ) + + self.post_init() + + @staticmethod + def _initialize_quantization_config( + config: Union[dict, OVWeightQuantizationConfig, OVQuantizationConfig], + config_type: Type[Union[OVWeightQuantizationConfig, OVQuantizationConfig]], + ): + if isinstance(config, dict): + return config_type.from_dict(config) + elif isinstance(config, config_type): + return config.clone() + else: + raise ValueError( + f"Unsupported type of quantization config. Expected either a dictionary or an instance of " + f"{config_type}, but found: {type(config)}." + ) + + def to_dict(self): + result = super().to_dict() + result["weight_quantization_config"] = self.weight_quantization_config.to_dict() + result["full_quantization_config"] = self.full_quantization_config.to_dict() + return result diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 3902deff4c..932b505b70 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -250,6 +250,11 @@ def fix_op_names_duplicates(model: openvino.runtime.Model): from optimum.intel.openvino.quantization import _weight_only_quantization + if not isinstance(quantization_config, (dict, OVWeightQuantizationConfig)): + raise TypeError( + f"Expected `quantization_config` to be either a dictionary or OVWeightQuantizationConfig object, got {type(quantization_config)}." 
+ ) + model = _weight_only_quantization(model, quantization_config) return model @@ -378,7 +383,7 @@ def _from_pretrained( compile_only = kwargs.get("compile_only", False) - quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit) model = None if not compile_only: @@ -481,14 +486,14 @@ def from_pretrained( ) @staticmethod - def _prepare_weight_quantization_config( + def _prepare_quantization_config( quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, load_in_8bit: bool = False ): # Give default quantization config if not provided and load_in_8bit=True if not quantization_config and load_in_8bit: quantization_config = OVWeightQuantizationConfig(bits=8) elif isinstance(quantization_config, dict): - quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config) + quantization_config = OVConfig.quantization_config_from_dict(quantization_config) return quantization_config diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index ba0d426e90..a61ec2bad8 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -196,7 +196,7 @@ def _from_pretrained( decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name decoder_with_past = None - quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit) compile_only = kwargs.get("compile_only", False) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 41e4c6e3e9..bff62194ed 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -847,7 +847,7 @@ def _from_pretrained( if quantization_config.get("dataset", None) is not None: quantization_config["trust_remote_code"] = kwargs.get("trust_remote_code", False) - quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit) enable_compilation = kwargs.pop("compile", True) and not quantization_config diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index f6c4fc37a8..5167513676 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -475,7 +475,7 @@ def _from_pretrained( kwargs[config_key] = value compile_only = kwargs.get("compile_only", False) - quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit) if (quantization_config is None or quantization_config.dataset is None) and not compile_only: for name, path in models.items(): if name in kwargs: diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index ccc5bb1b44..53306e1f7f 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -1029,6 +1029,7 @@ def _from_pretrained( ): compile_only = kwargs.get("compile_only", False) + quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit) if not compile_only and 
isinstance(quantization_config, OVQuantizationConfig): model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained( model_id, config, load_in_8bit=False, **kwargs diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index 4b2c5ee031..4cbea24b46 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -545,7 +545,7 @@ def _from_pretrained( except Exception: pass - quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) + quantization_config = model_cls._prepare_quantization_config(quantization_config, load_in_8bit) to_quantize = not compile_only and quantization_config is not None if to_quantize: kwargs["compile"] = False diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 2ba74244d8..9c3838cfa9 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -30,8 +30,7 @@ import torch import transformers from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from nncf import CompressWeightsMode, SensitivityMetric -from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters, OverflowFix +from nncf.quantization.advanced_parameters import OverflowFix from nncf.torch import register_module from nncf.torch.initialization import PTInitializingDataLoader from openvino._offline_transformations import compress_quantize_weights_transformation @@ -60,6 +59,7 @@ from ..utils.modeling_utils import get_model_device from .configuration import ( OVConfig, + OVMixedQuantizationConfig, OVQuantizationConfig, OVQuantizationConfigBase, OVQuantizationMethod, @@ -394,7 +394,7 @@ def _quantize_ovbasemodel( raise ValueError("Calibration dataset is required to run hybrid quantization.") if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): # Apply weight-only quantization to all SD submodels except UNet - quantization_config_copy = copy.deepcopy(quantization_config) + quantization_config_copy = quantization_config.clone() quantization_config_copy.dataset = None quantization_config_copy.quant_method = OVQuantizationMethod.DEFAULT sub_model_names = [ @@ -451,10 +451,7 @@ def _quantize_ovbasemodel( else: _weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs) self.model.request = None - else: - if not isinstance(quantization_config, OVQuantizationConfig): - raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") - + elif isinstance(quantization_config, OVQuantizationConfig): if calibration_dataset is None: raise ValueError("Calibration dataset is required to run quantization.") @@ -467,6 +464,15 @@ def _quantize_ovbasemodel( ) self.model.model = quantized_model self.model.request = None + elif isinstance(quantization_config, OVMixedQuantizationConfig): + if calibration_dataset is None: + raise ValueError("Calibration dataset is required to run quantization.") + + quantized_model = _mixed_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs) + self.model.model = quantized_model + self.model.request = None + else: + raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") if save_directory is not None: self.model.save_pretrained(save_directory) @@ -973,7 +979,7 @@ def transform_fn(data_item): def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kwargs): # 
Quantize encoder model # quantization_config.num_samples of audio samples result in more actual model inputs - config = copy.deepcopy(quantization_config) + config = quantization_config.clone() config.num_samples = calibration_dataset[0].get_length() quantized_encoder_model = _full_quantization( self.model.encoder_model, config, calibration_dataset[0], **kwargs @@ -983,7 +989,7 @@ def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kw self.model.encoder.request = None # Quantize decoder model - config = copy.deepcopy(quantization_config) + config = quantization_config.clone() config.num_samples = calibration_dataset[1].get_length() quantized_decoder_model = _full_quantization( self.model.decoder_model, config, calibration_dataset[1], **kwargs @@ -994,7 +1000,7 @@ def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kw if self.model.decoder_with_past_model is not None: # Quantize decoder with past model - config = copy.deepcopy(quantization_config) + config = quantization_config.clone() config.num_samples = calibration_dataset[2].get_length() quantized_decoder_w_p_model = _full_quantization( self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs @@ -1028,44 +1034,15 @@ def _weight_only_quantization( else: dataset = nncf.Dataset(calibration_dataset) - sensitivity_metric = None - if isinstance(config.sensitivity_metric, str): - sensitivity_metric = getattr(SensitivityMetric, config.sensitivity_metric.upper()) - - if config.weight_format == "mxfp4": - mode = CompressWeightsMode.E2M1 - elif config.weight_format == "nf4": - mode = CompressWeightsMode.NF4 - else: - if config.bits == 8: - mode = CompressWeightsMode.INT8_SYM if config.sym else CompressWeightsMode.INT8_ASYM - else: - mode = CompressWeightsMode.INT4_SYM if config.sym else CompressWeightsMode.INT4_ASYM - + wc_kwargs = copy.deepcopy(kwargs) + wc_kwargs.update(config.to_nncf_dict()) compressed_model = nncf.compress_weights( model, - mode=mode, - ratio=config.ratio, - group_size=config.group_size, - all_layers=config.all_layers, - sensitivity_metric=sensitivity_metric, - awq=getattr(config.quant_method, "name", "") == "AWQ" or None, - ignored_scope=config.get_ignored_scope_instance(), dataset=dataset, - subset_size=config.num_samples if config.num_samples else 128, - scale_estimation=config.scale_estimation, - gptq=config.gptq, - lora_correction=config.lora_correction, - backup_mode=None if config.backup_precision is None else nncf.BackupMode(config.backup_precision), - **kwargs, + **wc_kwargs, ) - # If KV cache compression was disabled, remove the disabling flag from the model - if compressed_model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]): - prev_rt_info = compressed_model.get_rt_info("runtime_options").value - if prev_rt_info["KV_CACHE_PRECISION"] == "f16": - prev_rt_info.pop("KV_CACHE_PRECISION") - compressed_model.set_rt_info(prev_rt_info, "runtime_options") + _remove_f16_kv_cache_precision_flag(compressed_model) return compressed_model @@ -1074,37 +1051,16 @@ def _full_quantization( model: openvino.runtime.Model, quantization_config: OVQuantizationConfig, calibration_dataset: nncf.Dataset, + verify_not_optimized: bool = True, **kwargs, ): - _verify_not_optimized(model) - advanced_parameters_kwargs = {} - if quantization_config.smooth_quant_alpha is not None: - advanced_parameters_kwargs["smooth_quant_alphas"] = AdvancedSmoothQuantParameters( - matmul=quantization_config.smooth_quant_alpha - ) - - q_mode_map = { - "f8e4m3": 
nncf.QuantizationMode.FP8_E4M3, - "f8e5m2": nncf.QuantizationMode.FP8_E5M2, - } + if verify_not_optimized: + _verify_not_optimized(model) + q_kwargs = copy.deepcopy(kwargs) + q_kwargs.update(quantization_config.to_nncf_dict()) + quantized_model = nncf.quantize(model, calibration_dataset=calibration_dataset, **q_kwargs) - if quantization_config.activation_format in q_mode_map: - kwargs.update({"mode": q_mode_map[quantization_config.activation_format]}) - - quantized_model = nncf.quantize( - model, - calibration_dataset, - subset_size=quantization_config.num_samples if quantization_config.num_samples else 128, - ignored_scope=quantization_config.get_ignored_scope_instance(), - model_type=nncf.ModelType(quantization_config.model_type), - preset=nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED, - fast_bias_correction=quantization_config.fast_bias_correction, - advanced_parameters=nncf.AdvancedQuantizationParameters( - overflow_fix=OverflowFix(quantization_config.overflow_fix), - **advanced_parameters_kwargs, - ), - **kwargs, - ) + _remove_f16_kv_cache_precision_flag(quantized_model) return quantized_model @@ -1172,31 +1128,78 @@ def _hybrid_quantization( Returns: The OpenVINO Runtime model with applied hybrid quantization. """ - ops_to_compress = _collect_ops_with_weights(model) - - wc_config = copy.deepcopy(quantization_config) - wc_config.ignored_scope = wc_config.ignored_scope or {} - - wc_ignored_types = ["Convolution"] if any(op.get_type_name() == "Convolution" for op in model.get_ops()) else [] - wc_config.ignored_scope["types"] = wc_config.ignored_scope.get("types", []) + wc_ignored_types - compressed_model = _weight_only_quantization(model, wc_config, **kwargs) - - ptq_ignored_scope = quantization_config.get_ignored_scope_instance() - ptq_ignored_scope.names += ops_to_compress - - subset_size = quantization_config.num_samples if quantization_config.num_samples else 200 - quantized_model = nncf.quantize( - model=compressed_model, - calibration_dataset=dataset, - model_type=nncf.ModelType.TRANSFORMER, - ignored_scope=ptq_ignored_scope, - # SQ algo should be disabled for MatMul nodes because their weights are already compressed - advanced_parameters=nncf.AdvancedQuantizationParameters( - smooth_quant_alphas=AdvancedSmoothQuantParameters(matmul=-1) - ), - subset_size=subset_size, + + wc_config = quantization_config.clone() + wc_config.ignored_scope = {} + if any(op.get_type_name() == "Convolution" for op in model.get_ops()): + wc_config.ignored_scope["types"] = ["Convolution"] + + q_config_ignored_scope = {"names": _collect_ops_with_weights(model)} + q_config = OVQuantizationConfig( + ignored_scope=q_config_ignored_scope, + num_samples=quantization_config.num_samples or 200, + smooth_quant_alpha=-1, **kwargs, ) + + mixed_quantization_config = OVMixedQuantizationConfig( + weight_quantization_config=wc_config, + full_quantization_config=q_config, + ignored_scope=quantization_config.ignored_scope, + **kwargs, + ) + + return _mixed_quantization(model, mixed_quantization_config, dataset, **kwargs) + + +def _mixed_quantization( + model: openvino.Model, + quantization_config: OVMixedQuantizationConfig, + dataset: nncf.Dataset, + **kwargs, +) -> openvino.Model: + """ + Perform mixed precision quantization where we separately quantize: + (1) weights of weighted layers to the precision given in the `quantization_config.weight_quantization_config`, and + (2) weights and activations of other possible layers; precision is given in the 
`quantization_config.full_quantization_config`. + + By default, weights of all weighted layers are quantized in the first step. In the second step activations of + weighted and non-weighted layers are quantized. If some layers are instructed to be ignored in the first step + with `weight_quantization_config.ignored_scope` parameter, both weights and activations of these layers are + quantized to the precision given in the `full_quantization_config`. + + Args: + model (`openvino.runtime.Model`): + The OpenVINO Runtime model for applying quantization. + quantization_config (`OVMixedQuantizationConfig`): + The configuration containing the parameters related to quantization. + dataset (`nncf.Dataset`): + The dataset used for quantization. + Returns: + The OpenVINO Runtime model with applied quantization. + """ + + def merge_ignored_scopes( + ignored_scope_1: Union[Dict[str, List[str]], None], ignored_scope_2: Union[Dict[str, List[str]], None] + ) -> Dict[str, List[str]]: + if ignored_scope_1 is None: + return copy.deepcopy(ignored_scope_2) if ignored_scope_2 is not None else None + if ignored_scope_2 is None: + return copy.deepcopy(ignored_scope_1) + merged_ignored_scope = {} + for key in set(ignored_scope_1) | set(ignored_scope_2): + merged_ignored_scope[key] = list(set(ignored_scope_1.get(key, []) + ignored_scope_2.get(key, []))) + return merged_ignored_scope + + wc_config = quantization_config.weight_quantization_config.clone() + wc_config.ignored_scope = merge_ignored_scopes(wc_config.ignored_scope, quantization_config.ignored_scope) + wc_dataset = dataset if wc_config.bits != 8 else None + compressed_model = _weight_only_quantization(model, wc_config, wc_dataset, **kwargs) + + q_config = quantization_config.full_quantization_config.clone() + q_config.ignored_scope = merge_ignored_scopes(q_config.ignored_scope, quantization_config.ignored_scope) + quantized_model = _full_quantization(compressed_model, q_config, dataset, verify_not_optimized=False, **kwargs) + return quantized_model @@ -1215,3 +1218,13 @@ def _verify_not_optimized(ov_model): raise RuntimeError(message_template.format(model_weight_compression_config)) elif model_quantization_config is not None: raise RuntimeError(message_template.format(model_quantization_config)) + + +def _remove_f16_kv_cache_precision_flag(model: openvino.Model) -> openvino.Model: + # Remove the KV cache compression disabling flag from the model + if model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]): + prev_rt_info = model.get_rt_info("runtime_options").value + if prev_rt_info["KV_CACHE_PRECISION"] == "f16": + prev_rt_info.pop("KV_CACHE_PRECISION") + model.set_rt_info(prev_rt_info, "runtime_options") + return model diff --git a/optimum/intel/utils/dummy_openvino_and_nncf_objects.py b/optimum/intel/utils/dummy_openvino_and_nncf_objects.py index e646074e1e..4b96d28589 100644 --- a/optimum/intel/utils/dummy_openvino_and_nncf_objects.py +++ b/optimum/intel/utils/dummy_openvino_and_nncf_objects.py @@ -68,3 +68,14 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["openvino", "nncf"]) + + +class OVMixedQuantizationConfig(metaclass=DummyObject): + _backends = ["openvino", "nncf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["openvino", "nncf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["openvino", "nncf"]) diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 
4bd47b535b..6ac70a47bf 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -14,7 +14,7 @@ import subprocess import unittest from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List from parameterized import parameterized from transformers import AutoModelForCausalLM @@ -128,16 +128,70 @@ class OVCLIExportTestCase(unittest.TestCase): "whisper", "int8", "--dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code", - (14, 22, 21) if is_transformers_version("<=", "4.36.0") else (14, 22, 25), - (14, 21, 17) if is_transformers_version("<=", "4.36.0") else (14, 22, 18), + [14, 22, 21] if is_transformers_version("<=", "4.36.0") else [14, 22, 25], + [{"int8": 14}, {"int8": 21}, {"int8": 17}] + if is_transformers_version("<=", "4.36.0") + else [{"int8": 14}, {"int8": 22}, {"int8": 18}], ), ( "text-generation", "llama", "f8e4m3", "--dataset wikitext2 --smooth-quant-alpha 0.9 --trust-remote-code", - (13,), - (16,), + [ + 13, + ], + [ + {"f8e4m3": 16}, + ], + ), + ( + "text-generation", + "llama", + "nf4_f8e4m3", + "--dataset wikitext2 --num-samples 1 --group-size 16 --trust-remote-code --ratio 0.5", + [ + 14, + ], + [ + {"f8e4m3": 11, "nf4": 5}, + ], + ), + ( + "text-generation", + "llama", + "nf4_f8e5m2", + "--dataset wikitext2 --num-samples 1 --group-size 16 --trust-remote-code --sym --ratio 0.5", + [ + 14, + ], + [ + {"f8e5m2": 11, "nf4": 5}, + ], + ), + ( + "text-generation", + "llama", + "int4_f8e4m3", + "--dataset wikitext2 --num-samples 1 --group-size 16 --trust-remote-code --sym --ratio 0.5", + [ + 14, + ], + [ + {"f8e4m3": 11, "int4": 5}, + ], + ), + ( + "text-generation", + "llama", + "int4_f8e5m2", + "--dataset wikitext2 --num-samples 1 --group-size 16 --trust-remote-code", + [ + 13, + ], + [ + {"f8e5m2": 2, "int4": 28}, + ], ), ] @@ -438,8 +492,8 @@ def test_exporters_cli_full_quantization( model_type: str, quant_mode: str, option: str, - expected_fake_nodes: Tuple[int], - expected_low_precision_nodes: Tuple[int], + expected_fake_nodes_per_model: List[int], + expected_num_weight_nodes_per_model: List[Dict[str, int]], ): with TemporaryDirectory() as tmpdir: subprocess.run( @@ -449,18 +503,24 @@ def test_exporters_cli_full_quantization( ) model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(tmpdir) - models = [model] if task == "automatic-speech-recognition": - models = [model.encoder, model.decoder] + submodels = [model.encoder, model.decoder] if model.decoder_with_past is not None: - models.append(model.decoder_with_past) + submodels.append(model.decoder_with_past) else: - expected_fake_nodes = expected_fake_nodes[:-1] - self.assertEqual(len(expected_fake_nodes), len(models)) - for i, model in enumerate(models): - num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model) - self.assertEqual(expected_fake_nodes[i], num_fake_nodes) - self.assertEqual(expected_low_precision_nodes[i], num_weight_nodes[quant_mode]) + expected_num_weight_nodes_per_model = expected_num_weight_nodes_per_model[:-1] + expected_fake_nodes_per_model = expected_fake_nodes_per_model[:-1] + elif "text-generation" in task: + submodels = [model] + else: + raise Exception("Unexpected task.") + + check_compression_state_per_model( + self, + submodels, + expected_num_weight_nodes_per_model, + expected_fake_nodes_per_model, + ) def test_exporters_cli_int4_with_local_model_and_default_config(self): with TemporaryDirectory() as tmpdir: diff --git a/tests/openvino/test_quantization.py 
b/tests/openvino/test_quantization.py index 1a84a14151..0e3e0212f2 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -63,6 +63,7 @@ OVSanaPipeline, OVTrainer, OVQuantizationConfig, + OVMixedQuantizationConfig, OVWeightQuantizationConfig, OVDynamicQuantizationConfig, OVModelOpenCLIPForZeroShotImageClassification, @@ -105,29 +106,131 @@ class OVQuantizerTest(unittest.TestCase): ( OVModelForSpeechSeq2Seq, "whisper", - OVQuantizationConfig( + dict( dataset="librispeech", num_samples=1, processor=MODEL_NAMES["whisper"], trust_remote_code=True, - weight_only=False, smooth_quant_alpha=0.95, ), - (14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 22, 25), - (14, 21, 17) if is_transformers_version("<=", "4.42.4") else (14, 22, 18), + [14, 22, 21] if is_transformers_version("<=", "4.36.0") else [14, 22, 25], + [{"int8": 14}, {"int8": 21}, {"int8": 17}] + if is_transformers_version("<=", "4.36.0") + else [{"int8": 14}, {"int8": 22}, {"int8": 18}], ), ( OVModelForCausalLM, "llama", - OVQuantizationConfig( + dict( dataset="wikitext2", num_samples=1, + dtype="f8e4m3", weight_only=False, - weight_format="f8e4m3", - activation_format="f8e4m3", ), - (13,), - (16,), + [ + 13, + ], + [ + {"f8e4m3": 16}, + ], + ), + ( + OVModelForCausalLM, + "llama", + dict( + weight_quantization_config=dict(bits=4, dtype="nf4", group_size=16, weight_only=True, ratio=0.5), + full_quantization_config=dict(dtype="f8e4m3", weight_only=False), + dataset="wikitext2", + num_samples=1, + ), + [ + 14, + ], + [ + {"f8e4m3": 11, "nf4": 5}, + ], + ), + ( + OVModelForCausalLM, + "llama", + OVMixedQuantizationConfig( + weight_quantization_config=OVWeightQuantizationConfig( + bits=4, + dtype="nf4", + group_size=16, + ratio=0.5, + ignored_scope={"patterns": ["^__module.model.layers.0.self_attn"]}, + ), + full_quantization_config=OVQuantizationConfig( + dtype="f8e4m3", ignored_scope={"patterns": ["^__module.model.layers.0.mlp"]} + ), + ignored_scope={"patterns": ["^__module.model.layers.1.self_attn"]}, + dataset="wikitext2", + num_samples=1, + ), + [ + 7, + ], + [ + {"f8e4m3": 8, "nf4": 2}, + ], + ), + ( + OVModelForCausalLM, + "llama", + OVMixedQuantizationConfig( + weight_quantization_config=OVWeightQuantizationConfig( + bits=4, + dtype="nf4", + group_size=16, + ratio=0.5, + ignored_scope={"patterns": ["^__module.model.layers.0.self_attn"]}, + ), + full_quantization_config=OVQuantizationConfig( + dtype="f8e5m2", ignored_scope={"patterns": ["^__module.model.layers.0.mlp"]} + ), + ignored_scope={"patterns": ["^__module.model.layers.1.self_attn"]}, + dataset="wikitext2", + num_samples=1, + ), + [ + 7, + ], + [ + {"f8e5m2": 8, "nf4": 2}, + ], + ), + ( + OVModelForCausalLM, + "llama", + OVMixedQuantizationConfig( + weight_quantization_config=OVWeightQuantizationConfig(bits=4, group_size=16, ratio=0.5), + full_quantization_config=OVQuantizationConfig(dtype="f8e4m3"), + dataset="wikitext2", + num_samples=1, + ), + [ + 14, + ], + [ + {"f8e4m3": 11, "int4": 10}, + ], + ), + ( + OVModelForCausalLM, + "llama", + OVMixedQuantizationConfig( + weight_quantization_config=OVWeightQuantizationConfig(bits=4, group_size=16), + full_quantization_config=OVQuantizationConfig(dtype="f8e5m2"), + dataset="wikitext2", + num_samples=1, + ), + [ + 13, + ], + [ + {"f8e5m2": 2, "int4": 28}, + ], ), ] @@ -222,35 +325,31 @@ def preprocess_function(examples, tokenizer): @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET) def test_ov_model_static_quantization_with_auto_dataset( - self, 
model_cls, model_name, quantization_config, expected_fake_nodes, expected_low_precision_nodes + self, + model_cls, + model_name, + quantization_config, + expected_fake_nodes_per_model, + expected_num_weight_nodes_per_model, ): model_id = MODEL_NAMES[model_name] - quant_mode = quantization_config.activation_format with TemporaryDirectory() as tmp_dir: ov_model = model_cls.from_pretrained(model_id, quantization_config=quantization_config) ov_model.save_pretrained(tmp_dir) if model_cls == OVModelForSpeechSeq2Seq: - models = [ov_model.encoder.model, ov_model.decoder.model] - + submodels = [ov_model.encoder.model, ov_model.decoder.model] if ov_model.decoder_with_past is not None: - models.append(ov_model.decoder_with_past.model) - for model, expected_fake_nodes, expected_lp_nodes in zip( - models, - expected_fake_nodes, - expected_low_precision_nodes, - ): - num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model) - self.assertEqual(expected_fake_nodes, num_fake_nodes) - self.assertEqual(expected_lp_nodes, num_weight_nodes[quant_mode]) + submodels.append(ov_model.decoder_with_past.model) + else: + expected_num_weight_nodes_per_model = expected_num_weight_nodes_per_model[:-1] + expected_fake_nodes_per_model = expected_fake_nodes_per_model[:-1] input_features = torch.randn((1, 128, 3000), dtype=torch.float32) ov_model.generate(input_features) elif model_cls == OVModelForCausalLM: - num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model.model) - self.assertEqual(expected_fake_nodes[0], num_fake_nodes) - self.assertEqual(expected_low_precision_nodes[0], num_weight_nodes[quant_mode]) + submodels = [ov_model] tokenizer = AutoTokenizer.from_pretrained(model_id) if tokenizer.pad_token is None: @@ -261,6 +360,13 @@ def test_ov_model_static_quantization_with_auto_dataset( else: raise Exception("Unexpected model class.") + check_compression_state_per_model( + self, + submodels, + expected_num_weight_nodes_per_model, + expected_fake_nodes_per_model, + ) + class OVWeightCompressionTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = ( @@ -284,14 +390,14 @@ class OVWeightCompressionTest(unittest.TestCase): OVModelForCausalLM, "gpt2", False, - dict(bits=4, weight_format="mxfp4", group_size=32), + dict(bits=4, dtype="mxfp4", group_size=32), [{"int8": 4, "f4e2m1": 20, "f8e8m0": 20}], ), ( OVModelForCausalLM, "gpt2", False, - dict(bits=4, weight_format="nf4", group_size=32), + dict(bits=4, dtype="nf4", group_size=32), [ { "int8": 4, @@ -854,7 +960,7 @@ def test_ovmodel_4bit_auto_compression_with_config( openvino_config = OVConfig.from_pretrained(tmp_dir) self.assertEqual(openvino_config.quantization_config.bits, 4) - self.assertEqual(openvino_config.dtype, quantization_config.weight_format) + self.assertEqual(openvino_config.dtype, quantization_config.dtype) @parameterized.expand(((OVModelForCausalLM, "gpt2"),)) def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type): @@ -1011,7 +1117,7 @@ def test_ovmodel_4bit_dynamic_with_config( model.save_pretrained(tmp_dir) openvino_config = OVConfig.from_pretrained(tmp_dir) self.assertEqual(openvino_config.quantization_config.bits, 4) - self.assertEqual(openvino_config.dtype, quantization_config.weight_format) + self.assertEqual(openvino_config.dtype, quantization_config.dtype) class OVQuantizerQATest(unittest.TestCase): @@ -1284,7 +1390,7 @@ def test_config_from_dict(self, quantization_config: dict, config_type: type, wa @parameterized.expand(DEFAULT_CONFIGURATIONS) def 
test_named_default_configurations(self, config_id: str): custom_configuration = self.DEFAULT_CONFIGURATIONS[config_id] - prepared_config = OVModelForCausalLM._prepare_weight_quantization_config(custom_configuration) + prepared_config = OVModelForCausalLM._prepare_quantization_config(custom_configuration) for field_name, reference_value in custom_configuration.items(): value = prepared_config.__getattribute__(field_name) self.assertEqual(value, reference_value) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index e4c2ede8e9..f8e7a4add1 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest from contextlib import contextmanager -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import numpy as np import openvino as ov @@ -293,14 +293,25 @@ def new_forward( def check_compression_state_per_model( test_case: unittest.TestCase, models: List[Union[ov.Model, OVBaseModel]], - expected_num_weight_nodes_per_model: List[Dict], + expected_num_weight_nodes_per_model: List[Dict[str, int]], + expected_num_fake_nodes_per_model: Optional[List[int]] = None, ): test_case.assertEqual(len(models), len(expected_num_weight_nodes_per_model)) - actual_num_weights_per_model = [] - for submodel, expected_num_weight_nodes in zip(models, expected_num_weight_nodes_per_model): + actual_num_weights_per_model = [{}] * len(models) + actual_num_fake_nodes_per_model = [0] * len(models) + for i, (submodel, expected_num_weight_nodes) in enumerate(zip(models, expected_num_weight_nodes_per_model)): ov_model = submodel if isinstance(submodel, ov.Model) else submodel.model - _, num_weight_nodes = get_num_quantized_nodes(ov_model) + num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model) expected_num_weight_nodes.update({k: 0 for k in set(num_weight_nodes) - set(expected_num_weight_nodes)}) - actual_num_weights_per_model.append(num_weight_nodes) + + actual_num_weights_per_model[i] = num_weight_nodes + actual_num_fake_nodes_per_model[i] = num_fake_nodes + test_case.assertFalse(ov_model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"])) + + # Check weight nodes test_case.assertEqual(expected_num_weight_nodes_per_model, actual_num_weights_per_model) + + # Check fake nodes + if expected_num_fake_nodes_per_model is not None: + test_case.assertEqual(expected_num_fake_nodes_per_model, actual_num_fake_nodes_per_model)
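Below is a minimal, hedged sketch of driving one of the new mixed `--quant-mode` values from the command line, written as a Python `subprocess` call so it mirrors the invocation pattern used in `tests/openvino/test_exporters_cli.py`. The model id and output directory are illustrative placeholders, not part of this diff.

```python
# Sketch only: export a causal LM with 4-bit NF4 weight compression plus f8e4m3 full quantization.
# The model id and output directory below are placeholders.
import subprocess

subprocess.run(
    "optimum-cli export openvino --model hf-internal-testing/tiny-random-LlamaForCausalLM "
    "--task text-generation --quant-mode nf4_f8e4m3 "
    "--dataset wikitext2 --num-samples 1 --group-size 16 --ratio 0.5 --trust-remote-code "
    "llama_nf4_f8e4m3_ov",
    shell=True,
    check=True,
)
```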
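The same setup is reachable through the Python API with the new `OVMixedQuantizationConfig`. The sketch below follows one of the test cases added in `tests/openvino/test_quantization.py`; the checkpoint name is a placeholder, and `export=True` is assumed to be acceptable for a non-OpenVINO checkpoint.

```python
from optimum.intel import (
    OVMixedQuantizationConfig,
    OVModelForCausalLM,
    OVQuantizationConfig,
    OVWeightQuantizationConfig,
)

quantization_config = OVMixedQuantizationConfig(
    # Step 1: compress weights of weighted layers to NF4 (half of them, via ratio=0.5).
    weight_quantization_config=OVWeightQuantizationConfig(bits=4, dtype="nf4", group_size=16, ratio=0.5),
    # Step 2: quantize the remaining weights and all activations to f8e4m3.
    full_quantization_config=OVQuantizationConfig(dtype="f8e4m3"),
    dataset="wikitext2",
    num_samples=1,
)

model = OVModelForCausalLM.from_pretrained(
    "path/or/id-of-a-causal-lm",  # placeholder checkpoint
    export=True,
    quantization_config=quantization_config,
)
```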
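Because `OVConfig.quantization_config_from_dict` (the renamed `_quantization_config_from_dict`) dispatches on the presence of both nested keys, the same configuration can also be expressed as a plain dictionary, for example when passed through `from_pretrained(..., quantization_config=...)`. A small sketch of the expected dispatch:

```python
from optimum.intel import OVConfig, OVMixedQuantizationConfig

config_dict = {
    "weight_quantization_config": {"bits": 4, "dtype": "nf4", "group_size": 16, "ratio": 0.5},
    "full_quantization_config": {"dtype": "f8e4m3"},
    "dataset": "wikitext2",
    "num_samples": 1,
}

quantization_config = OVConfig.quantization_config_from_dict(config_dict)
assert isinstance(quantization_config, OVMixedQuantizationConfig)
assert quantization_config.weight_quantization_config.dtype == "nf4"
assert quantization_config.full_quantization_config.dtype == "f8e4m3"
```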
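The renamed parameters stay backward compatible: `weight_format` and `activation_format` are still accepted as keyword arguments, emit the deprecation warnings introduced above, and are mapped onto the new `dtype` attribute. A quick sketch of the expected behaviour:

```python
from optimum.intel import OVQuantizationConfig, OVWeightQuantizationConfig

wc = OVWeightQuantizationConfig(bits=4, weight_format="nf4")  # warns, maps onto wc.dtype
q = OVQuantizationConfig(activation_format="f8e4m3")          # warns, maps onto q.dtype (and forces sym=True)

assert wc.dtype == "nf4"
assert q.dtype == "f8e4m3" and q.sym is True
```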
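`to_nncf_dict()` now concentrates the translation from config attributes to `nncf` keyword arguments, so `_weight_only_quantization` and `_full_quantization` can simply merge the result into their calls. A hedged sketch of the mapping expected for a symmetric 4-bit weight-compression config, based on the code above:

```python
import nncf
from optimum.intel import OVWeightQuantizationConfig

wc = OVWeightQuantizationConfig(bits=4, sym=True, group_size=64, ratio=0.8)
wc_kwargs = wc.to_nncf_dict()

# "int4" plus the "_sym" suffix is resolved to the corresponding nncf compression mode.
assert wc_kwargs["mode"] == nncf.CompressWeightsMode.INT4_SYM
assert wc_kwargs["group_size"] == 64 and wc_kwargs["ratio"] == 0.8
assert wc_kwargs["subset_size"] == 128  # default when num_samples is not set
```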