From 47dbe38595fca3439bebddb375b08ee334695c6e Mon Sep 17 00:00:00 2001 From: Nikolay Date: Mon, 17 Mar 2025 14:00:04 +0100 Subject: [PATCH 1/8] strip in one commit --- nncf/torch/quantization/strip.py | 172 ++++++++++++++++++++++++- tests/torch/helpers.py | 11 ++ tests/torch/ptq/test_fq_lora.py | 62 ++++++++- tests/torch/quantization/test_strip.py | 39 ++++++ tests/torch/requirements.txt | 4 + 5 files changed, 285 insertions(+), 3 deletions(-) diff --git a/nncf/torch/quantization/strip.py b/nncf/torch/quantization/strip.py index 2fcec2a6f1c..0a60a5f2e77 100644 --- a/nncf/torch/quantization/strip.py +++ b/nncf/torch/quantization/strip.py @@ -10,16 +10,37 @@ # limitations under the License. +from typing import List + import numpy as np import torch from torch.quantization.fake_quantize import FakeQuantize import nncf +from nncf.common.graph.transformations.commands import Command +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled +from nncf.experimental.torch2.commands import PT2InsertionCommand +from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_graph_manager import get_const_node +from nncf.torch.model_graph_manager import get_module_by_name +from nncf.torch.model_graph_manager import split_const_name +from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.layers import AsymmetricLoraQuantizer from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import SymmetricLoraQuantizer from nncf.torch.quantization.layers import SymmetricQuantizer +from nncf.torch.quantization.quantize_functions import TuneRange SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8] @@ -171,6 +192,153 @@ def strip_quantized_model(model: NNCFNetwork): :param model: Compressed model. :return: The modified NNCF network. """ - model = replace_quantizer_to_torch_native_module(model) - model = remove_disabled_quantizers(model) + model_layout = model.nncf.transformation_layout() + transformations = model_layout.transformations + if any([type(q.fn) in [AsymmetricLoraQuantizer, SymmetricLoraQuantizer] for q in transformations]): + model = replace_with_decompressors(model, transformations) + else: + model = replace_quantizer_to_torch_native_module(model) + model = remove_disabled_quantizers(model) return model + + +def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork: + """ + Performs transformation from fake quantize format (FQ) to dequantization one (DQ). + The former takes floating-point input, quantizes and dequantizes, and returns a floating-point value, + while the latter takes a quantized integer representation, dequantizes it, and outputs a floating-point result. + + Mathematically, both methods lead to the same outcome, but due to differences in the order of operations and + rounding errors, the actual results may differ. In particular, this error can occur for values + that are located in the midpoint between two quantized values ("quants"). + + The FQ format may round these values to one "quant", while the DQ format rounds them to another "quant". + To avoid these issues, the compressed representation should be provided not by directly quantizing the input, + but by quantizing a pre-processed, fake-quantized, floating-point representation. + + :param model: Compressed model with Decompressors. + :return: The modified NNCF network. + """ + transformation_layout = TransformationLayout() + model = model.nncf.get_clean_shallow_copy() + graph = model.nncf.get_graph() + + for command in transformations: + quantizer = command.fn + + if len(command.target_points) > 1: + msg = "Command contains more than one target point!" + raise nncf.ValidationError(msg) + + tp = command.target_points[0] + node_with_weight = graph.get_node_by_name(tp.target_node_name) + weight_node = get_const_node(node_with_weight, tp.input_port_id, graph) + + module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name) + module = get_module_by_name(module_name, model) + original_weight = getattr(module, weight_attr_name) + + original_dtype = original_weight.dtype + original_shape = original_weight.shape + original_eps = torch.finfo(original_dtype).eps + + qdq_weight = quantizer.quantize(original_weight) + if hasattr(quantizer, "_lspec"): + # Special reshape for LoRA-grouped output + qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape) + qdq_weight = qdq_weight.to(original_dtype) + + if isinstance(quantizer, AsymmetricQuantizer): + input_range_safe = abs(quantizer.input_range) + quantizer.eps + input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels) + + integer_dtype = torch.uint8 + + input_low = input_low.to(original_dtype) + input_range = input_range.to(original_dtype) + + scale = input_range / quantizer.level_high + scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale) + scale = scale.to(original_dtype) + + zero_point = quantizer.level_low - torch.round(input_low / scale) + zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high) + zero_point = zero_point.to(integer_dtype) + + q_weight = qdq_weight / scale + q_weight = q_weight + zero_point + q_weight = torch.round(q_weight) + q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high) + q_weight = q_weight.to(integer_dtype) + + if quantizer.num_bits == 8: + decompressor = INT8AsymmetricWeightsDecompressor( + scale=scale, zero_point=zero_point, result_dtype=original_dtype + ) + else: + decompressor = INT4AsymmetricWeightsDecompressor( + scale=scale, + zero_point=zero_point, + compressed_weight_shape=q_weight.shape, + result_shape=original_shape, + result_dtype=original_dtype, + ) + + elif isinstance(quantizer, SymmetricQuantizer): + integer_dtype = torch.int8 + + scale = quantizer.scale / abs(quantizer.level_low) + scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale) + scale = scale.to(original_dtype) + + q_weight = qdq_weight / scale + q_weight = torch.round(q_weight) + q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high) + q_weight = q_weight.to(integer_dtype) + + if quantizer.num_bits == 8: + decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=original_dtype) + else: + decompressor = INT4SymmetricWeightsDecompressor( + scale=scale, + compressed_weight_shape=q_weight.shape, + result_shape=original_shape, + result_dtype=original_dtype, + ) + + packed_tensor = decompressor.pack_weight(q_weight) + + # sets compressed tensor + compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False) + setattr(module, weight_attr_name, compressed_parameter) + + consumer_nodes = graph.get_next_nodes(weight_node) + if len(consumer_nodes) > 1: + for consumer_node in consumer_nodes: + consumer_module = model.nncf.get_module_by_scope(Scope.from_str(consumer_node.layer_name)) + for name, param in consumer_module.named_parameters(recurse=False, remove_duplicate=False): + if id(param) == id(original_weight): + setattr(consumer_module, name, compressed_parameter) + + if is_experimental_torch_tracing_enabled(): + transformation_layout.register( + PT2InsertionCommand( + [ + PTTargetPoint( + TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":") + ) + ], + decompressor, + ) + ) + else: + decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" + transformation_layout.register( + PTSharedFnInsertionCommand( + [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)], + decompressor, + decompressor_name, + ) + ) + + return PTModelTransformer(model).transform(transformation_layout) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 5c3bd5086c6..730a0930c41 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -773,3 +773,14 @@ def _check_pre_post_hooks( assert len(actual_hooks) == len(ref_hooks) for actual_hook, ref_hook in zip(actual_hooks, ref_hooks): assert actual_hook is ref_hook + + +class LinearModel(nn.Module): + def __init__(self, input_shape=List[int]): + super().__init__() + with set_torch_seed(): + self.linear = nn.Linear(input_shape[1], input_shape[0], bias=False) + self.linear.weight.data = torch.randn(input_shape) - 0.5 + + def forward(self, x): + return self.linear(x) diff --git a/tests/torch/ptq/test_fq_lora.py b/tests/torch/ptq/test_fq_lora.py index 327733c95b8..fc44b29902d 100644 --- a/tests/torch/ptq/test_fq_lora.py +++ b/tests/torch/ptq/test_fq_lora.py @@ -11,6 +11,10 @@ import pytest import torch +from optimum.exporters.openvino.convert import export_from_model +from optimum.intel.openvino import OVModelForCausalLM +from sentence_transformers import SentenceTransformer +from sentence_transformers import util from transformers import AutoModelForCausalLM from transformers import AutoTokenizer @@ -20,6 +24,44 @@ from nncf.torch.quantization.layers import SymmetricQuantizer as SQ +class ValidationMock: + def __init__(self) -> None: + model_id = "sentence-transformers/all-mpnet-base-v2" + self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + self.model = SentenceTransformer( + model_id, tokenizer_kwargs={"pad_token": self.tokenizer.pad_token}, trust_remote_code=True + ) + + def calculate_similarity(self, gold: str, prediction: str) -> torch.Tensor: + embeddings = self.model.encode([gold, prediction]) + cos_sim = util.cos_sim(embeddings, embeddings) + return torch.mean(cos_sim) + + @property + def validation_ref(self) -> torch.Tensor: + return torch.tensor(1.0) + + +def generate_control_output(model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> torch.Tensor: + control_input = tokenizer("What is Pytorch?", return_tensors="pt") + control_input = control_input.to(model.device) + control_output = model.generate(**control_input, do_sample=False) + return tokenizer.batch_decode(control_output, skip_special_tokens=True)[0] + + +def get_ov_model(model: AutoModelForCausalLM, tmp_path: str) -> OVModelForCausalLM: + model = model.cpu() + export_from_model(model, tmp_path) + + return OVModelForCausalLM.from_pretrained( + model_id=tmp_path, + trust_remote_code=True, + load_in_8bit=False, + compile=True, + ov_config={"KV_CACHE_PRECISION": "f16", "DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"}, + ) + + @pytest.mark.parametrize( "compression_kwargs", (dict(scale_estimation=True, awq=True), dict(scale_estimation=False, awq=False)), @@ -33,7 +75,7 @@ ), ids=["asym", "sym"], ) -def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable, _seed): +def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num_trainable, _seed): model_id = "facebook/opt-125m" device = "cuda" if torch.cuda.is_available() else "cpu" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=device) @@ -80,3 +122,21 @@ def test_fq_lora_tuning(mode, backup_mode, compression_kwargs, ref_num_trainable assert first_loss > 8 assert float(loss) < 1 + + tuned_output = generate_control_output(model, tokenizer) + + # Workaround till export from the optimum would be fixed - CVS-164159 + model = model.to(torch.float32) + + model = nncf.strip(model) + stripped_output = generate_control_output(model, tokenizer) + + model = get_ov_model(model, tmp_path) + stripped_ov_output = generate_control_output(model, tokenizer) + + vm = ValidationMock() + tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output) + tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output) + + assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=0.01) + assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=0.01) diff --git a/tests/torch/quantization/test_strip.py b/tests/torch/quantization/test_strip.py index 8344781c3f8..89e97fd3494 100644 --- a/tests/torch/quantization/test_strip.py +++ b/tests/torch/quantization/test_strip.py @@ -34,6 +34,7 @@ from tests.common.quantization.data_generators import generate_sweep_data from tests.common.quantization.data_generators import get_quant_len_by_range from tests.torch.helpers import BasicConvTestModel +from tests.torch.helpers import LinearModel from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import register_bn_adaptation_init_args from tests.torch.quantization.test_functions import get_test_data @@ -325,3 +326,41 @@ def test_nncf_strip_api(strip_type, do_copy): assert isinstance(strip_model.conv.get_pre_op("0").op, FakeQuantize) assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize) + + +@pytest.mark.parametrize( + ("mode", "torch_dtype", "atol"), + ( + (nncf.CompressWeightsMode.INT4_ASYM, torch.float32, 0.0005), + (nncf.CompressWeightsMode.INT4_ASYM, torch.float16, 0.0005), + (nncf.CompressWeightsMode.INT4_ASYM, torch.bfloat16, 0.01), + (nncf.CompressWeightsMode.INT4_SYM, torch.float32, 0.0005), + (nncf.CompressWeightsMode.INT4_SYM, torch.float16, 0.0005), + (nncf.CompressWeightsMode.INT4_SYM, torch.bfloat16, 0.01), + ), +) +def test_nncf_strip_lora_model(mode, torch_dtype, atol): + input_shape = [1, 16] + model = LinearModel(input_shape=input_shape) + model = model.to(torch_dtype) + with torch.no_grad(): + example = torch.ones(input_shape).to(torch_dtype) + dataset = [example] + + compressed_model = nncf.compress_weights( + model, + ratio=1, + group_size=4, + mode=mode, + backup_mode=None, + dataset=nncf.Dataset(dataset), + all_layers=True, + compression_format=nncf.CompressionFormat.FQ_LORA, + ) + + compressed_output = compressed_model(example) + + strip_compressed_model = nncf.strip(compressed_model, do_copy=True) + stripped_output = strip_compressed_model(example) + + assert torch.allclose(compressed_output, stripped_output, atol=atol) diff --git a/tests/torch/requirements.txt b/tests/torch/requirements.txt index f09a969557c..d4a4ee35acc 100644 --- a/tests/torch/requirements.txt +++ b/tests/torch/requirements.txt @@ -24,3 +24,7 @@ timm==0.9.2 # Required for torch/fx tests torchvision fastdownload==0.0.7 + +sentence-transformers>=2.2.2 +optimum-intel==1.22.0 +optimum==1.24.0 From cc04be9bc782ac9f86fe405f0139d5fcdad30ce2 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Mon, 17 Mar 2025 19:21:26 +0100 Subject: [PATCH 2/8] Introduced strip format for more explicit behavior on strip --- nncf/__init__.py | 1 + nncf/api/compression.py | 15 +++-- nncf/common/composite_compression.py | 5 +- nncf/common/strip.py | 14 +++-- .../tensorflow/quantization/algorithm.py | 5 +- nncf/parameters.py | 17 ++++++ .../weight_compression/torch_backend.py | 9 ++- nncf/tensorflow/algorithm_selector.py | 3 +- nncf/tensorflow/pruning/base_algorithm.py | 5 +- nncf/tensorflow/quantization/algorithm.py | 5 +- nncf/tensorflow/sparsity/base_algorithm.py | 5 +- nncf/tensorflow/strip.py | 10 ++- nncf/torch/algo_selector.py | 3 +- nncf/torch/nncf_network.py | 13 ++-- nncf/torch/pruning/filter_pruning/algo.py | 5 +- nncf/torch/quantization/algo.py | 7 ++- nncf/torch/quantization/strip.py | 61 +++++++++---------- nncf/torch/sparsity/base_algo.py | 5 +- nncf/torch/strip.py | 12 ++-- tests/torch/ptq/test_fq_lora.py | 26 ++++---- tests/torch/quantization/test_strip.py | 44 ++++++------- 21 files changed, 170 insertions(+), 100 deletions(-) diff --git a/nncf/__init__.py b/nncf/__init__.py index ccabe75a410..353ea4174e4 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -40,6 +40,7 @@ from nncf.parameters import ModelType as ModelType from nncf.parameters import QuantizationMode as QuantizationMode from nncf.parameters import SensitivityMetric as SensitivityMetric +from nncf.parameters import StripFormat as StripFormat from nncf.parameters import TargetDevice as TargetDevice from nncf.quantization import QuantizationPreset as QuantizationPreset from nncf.quantization import compress_weights as compress_weights diff --git a/nncf/api/compression.py b/nncf/api/compression.py index 94b8cd27fd0..9ee08a4a094 100644 --- a/nncf/api/compression.py +++ b/nncf/api/compression.py @@ -19,6 +19,7 @@ from nncf.common.statistics import NNCFStatistics from nncf.common.utils.api_marker import api from nncf.common.utils.backend import copy_model +from nncf.parameters import StripFormat TModel = TypeVar("TModel") @@ -236,7 +237,9 @@ def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics: need to keep track of statistics on each training batch/step/iteration. """ - def strip_model(self, model: TModel, do_copy: bool = False) -> TModel: + def strip_model( + self, model: TModel, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> TModel: """ Strips auxiliary layers that were used for the model compression, as it's only needed for training. The method is used before exporting the model @@ -244,6 +247,7 @@ def strip_model(self, model: TModel, do_copy: bool = False) -> TModel: :param model: The compressed model. :param do_copy: Modify copy of the model, defaults to False. + :param strip format: Describes the format in which model is saved after strip. :return: The stripped model. """ if do_copy: @@ -256,16 +260,17 @@ def prepare_for_export(self) -> None: """ self._model = self.strip_model(self._model) - def strip(self, do_copy: bool = True) -> TModel: # type: ignore[type-var] + def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> TModel: # type: ignore[type-var] """ - Returns the model object with as much custom NNCF additions as possible removed - while still preserving the functioning of the model object as a compressed model. + Removes auxiliary layers and operations added during the compression process, resulting in a clean + model ready for deployment. The functionality of the model object is still preserved as a compressed model. :param do_copy: If True (default), will return a copy of the currently associated model object. If False, will return the currently associated model object "stripped" in-place. + :param strip format: Describes the format in which model is saved after strip. :return: The stripped model. """ - return self.strip_model(self.model, do_copy) # type: ignore + return self.strip_model(self.model, do_copy, strip_format) # type: ignore @abstractmethod def export_model( diff --git a/nncf/common/composite_compression.py b/nncf/common/composite_compression.py index 7b5d51003ec..709afe97853 100644 --- a/nncf/common/composite_compression.py +++ b/nncf/common/composite_compression.py @@ -23,6 +23,7 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import copy_model from nncf.common.utils.backend import get_backend +from nncf.parameters import StripFormat class CompositeCompressionLoss(CompressionLoss): @@ -276,12 +277,12 @@ def prepare_for_export(self) -> None: stripped_model = ctrl.strip_model(stripped_model) self._model = stripped_model - def strip(self, do_copy: bool = True) -> TModel: # type: ignore + def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> TModel: # type: ignore model = self.model if do_copy: model = copy_model(model) for ctrl in self.child_ctrls: - model = ctrl.strip_model(model, do_copy=False) + model = ctrl.strip_model(model, do_copy=False, strip_format=strip_format) return model # type: ignore @property diff --git a/nncf/common/strip.py b/nncf/common/strip.py index 3d5bee0168d..32dc0b3a15b 100644 --- a/nncf/common/strip.py +++ b/nncf/common/strip.py @@ -16,6 +16,7 @@ from nncf.common.utils.api_marker import api from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.parameters import StripFormat from nncf.telemetry.decorator import tracked_function from nncf.telemetry.events import MODEL_BASED_CATEGORY from nncf.telemetry.extractors import FunctionCallTelemetryExtractor @@ -25,25 +26,26 @@ @api(canonical_alias="nncf.strip") @tracked_function(category=MODEL_BASED_CATEGORY, extractors=[FunctionCallTelemetryExtractor("nncf.strip")]) -def strip(model: TModel, do_copy: bool = True) -> TModel: +def strip(model: TModel, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> TModel: """ - Returns the model object with as much custom NNCF additions as possible removed - while still preserving the functioning of the model object as a compressed model. + Removes auxiliary layers and operations added during the compression process, resulting in a clean + model ready for deployment. The functionality of the model object is still preserved as a compressed model. :param model: The compressed model. :param do_copy: If True (default), will return a copy of the currently associated model object. If False, will return the currently associated model object "stripped" in-place. + :param strip format: Describes the format in which model is saved after strip. :return: The stripped model. """ model_backend = get_backend(model) if model_backend == BackendType.TORCH: from nncf.torch.strip import strip as strip_pt - return strip_pt(model, do_copy) # type: ignore + return strip_pt(model, do_copy, strip_format) # type: ignore elif model_backend == BackendType.TENSORFLOW: from nncf.tensorflow.strip import strip as strip_tf - return strip_tf(model, do_copy) # type: ignore + return strip_tf(model, do_copy, strip_format) # type: ignore - msg = f"Method `strip` does not support for {model_backend.value} backend." + msg = f"Method `strip` does not support {model_backend.value} backend." raise nncf.UnsupportedBackendError(msg) diff --git a/nncf/experimental/tensorflow/quantization/algorithm.py b/nncf/experimental/tensorflow/quantization/algorithm.py index de5d16d86cb..aaf04e23eb3 100644 --- a/nncf/experimental/tensorflow/quantization/algorithm.py +++ b/nncf/experimental/tensorflow/quantization/algorithm.py @@ -35,6 +35,7 @@ from nncf.experimental.tensorflow.quantization.init_range import RangeInitializerV2 from nncf.experimental.tensorflow.quantization.init_range import TFRangeInitParamsV2 from nncf.experimental.tensorflow.quantization.quantizers import create_quantizer +from nncf.parameters import StripFormat from nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS from nncf.tensorflow.graph.metatypes.tf_ops import TFOpWithWeightsMetatype from nncf.tensorflow.graph.transformations.commands import TFInsertionCommand @@ -353,7 +354,9 @@ def apply_to(self, model: NNCFNetwork) -> NNCFNetwork: class QuantizationControllerV2(QuantizationController): - def strip_model(self, model: NNCFNetwork, do_copy: bool = False) -> NNCFNetwork: + def strip_model( + self, model: NNCFNetwork, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> NNCFNetwork: if do_copy: model = copy_model(model) return model diff --git a/nncf/parameters.py b/nncf/parameters.py index 9940a70443c..8ba084939cd 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -119,6 +119,23 @@ class CompressionFormat(StrEnum): FQ_LORA = "fake_quantize_with_lora" +@api(canonical_alias="nncf.StripFormat") +class StripFormat(StrEnum): + """ + Describes the format in which model is saved after strip: operation that removes auxiliary layers and + operations added during the compression process, resulting in a clean model ready for deployment. + The functionality of the model object is still preserved as a compressed model. + + :param NATIVE: Returns the model with as much custom NNCF additions as possible, + :param DQ: Replaces FakeQuantize operations with dequantization subgraph and compressed weights in low-bit + precision using fake quantize parameters. This is the default format for deployment of models with compressed + weights. + """ + + NATIVE = "native" + DQ = "dequantize" + + @api(canonical_alias="nncf.BackupMode") class BackupMode(StrEnum): """ diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index ee439e452a5..27db3e6912a 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -283,6 +283,7 @@ def get_fq_insertion_command( orig_weight_shape: Tuple[int, ...], compression_format: CompressionFormat, lora_adapter_rank: int, + is_all_8bit: bool, ) -> PTTransformationCommand: """ Creates a fake quantization insertion command for the given compressed weight. @@ -291,9 +292,11 @@ def get_fq_insertion_command( :param wc_params: Parameters for weight compression. :param orig_weight_shape: The original shape of the weight tensor. :param compression_format: The format of compression. + :param is_all_8bit: Flag indicating if all weights should be compressed to 8-bit. :return: A PTTransformationCommand for inserting fake quantization to the model. """ compression_config = wc_params.compression_config + # default mapping for 4bit weight compression and FQ_LORA format, no need to add lora adapters for 8bit weight mode_vs_schema_map = { CompressWeightsMode.INT4_ASYM: QuantizationScheme.ASYMMETRIC_LORA, CompressWeightsMode.INT4_SYM: QuantizationScheme.SYMMETRIC_LORA, @@ -303,6 +306,9 @@ def get_fq_insertion_command( if compression_format == CompressionFormat.FQ: mode_vs_schema_map[CompressWeightsMode.INT4_ASYM] = QuantizationScheme.ASYMMETRIC mode_vs_schema_map[CompressWeightsMode.INT4_SYM] = QuantizationScheme.SYMMETRIC + if is_all_8bit and compression_format == CompressionFormat.FQ_LORA: + mode_vs_schema_map[CompressWeightsMode.INT8_ASYM] = QuantizationScheme.ASYMMETRIC_LORA + mode_vs_schema_map[CompressWeightsMode.INT8_SYM] = QuantizationScheme.SYMMETRIC_LORA schema = mode_vs_schema_map[compression_config.mode] @@ -469,6 +475,7 @@ def transform_model( model_transformer = PTModelTransformer(model) transformation_layout = TransformationLayout() + is_all_8bit = all(wc_params.compression_config.num_bits == 8 for wc_params in weight_compression_parameters) for wc_params in weight_compression_parameters: compression_config = wc_params.compression_config if compression_config.mode in [ @@ -499,7 +506,7 @@ def transform_model( else: rank = advanced_parameters.lora_adapter_rank command = self.get_fq_insertion_command( - compressed_weight, wc_params, weight.shape, compression_format, rank + compressed_weight, wc_params, weight.shape, compression_format, rank, is_all_8bit ) transformation_layout.register(command) diff --git a/nncf/tensorflow/algorithm_selector.py b/nncf/tensorflow/algorithm_selector.py index c62b144a8da..0bf5b188b17 100644 --- a/nncf/tensorflow/algorithm_selector.py +++ b/nncf/tensorflow/algorithm_selector.py @@ -22,6 +22,7 @@ from nncf.common.statistics import NNCFStatistics from nncf.common.utils.backend import copy_model from nncf.common.utils.registry import Registry +from nncf.parameters import StripFormat from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder from nncf.tensorflow.loss import TFZeroCompressionLoss @@ -60,7 +61,7 @@ def scheduler(self) -> StubCompressionScheduler: def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics: return NNCFStatistics() - def strip(self, do_copy: bool = True) -> tf.keras.Model: + def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> tf.keras.Model: model = self.model if do_copy: model = copy_model(self.model) diff --git a/nncf/tensorflow/pruning/base_algorithm.py b/nncf/tensorflow/pruning/base_algorithm.py index c57b745593e..fbfeed2a792 100644 --- a/nncf/tensorflow/pruning/base_algorithm.py +++ b/nncf/tensorflow/pruning/base_algorithm.py @@ -35,6 +35,7 @@ from nncf.config.schemata.defaults import PRUNE_DOWNSAMPLE_CONVS from nncf.config.schemata.defaults import PRUNE_FIRST_CONV from nncf.config.schemata.defaults import PRUNING_INIT +from nncf.parameters import StripFormat from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder from nncf.tensorflow.graph.converter import TFModelConverterFactory from nncf.tensorflow.graph.metatypes.keras_layers import TFBatchNormalizationLayerMetatype @@ -359,6 +360,8 @@ def _calculate_pruned_layers_summary(self) -> List[PrunedLayerSummary]: return pruned_layers_summary - def strip_model(self, model: tf.keras.Model, do_copy: bool = False) -> tf.keras.Model: + def strip_model( + self, model: tf.keras.Model, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> tf.keras.Model: # Transform model for pruning creates copy of the model. return strip_model_from_masks(model, self._op_names) diff --git a/nncf/tensorflow/quantization/algorithm.py b/nncf/tensorflow/quantization/algorithm.py index a64983b0635..6d24193a842 100644 --- a/nncf/tensorflow/quantization/algorithm.py +++ b/nncf/tensorflow/quantization/algorithm.py @@ -56,6 +56,7 @@ from nncf.config.schemata.defaults import QUANTIZE_INPUTS from nncf.config.schemata.defaults import QUANTIZE_OUTPUTS from nncf.config.schemata.defaults import TARGET_DEVICE +from nncf.parameters import StripFormat from nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder from nncf.tensorflow.graph.converter import TFModelConverter @@ -753,7 +754,9 @@ def loss(self) -> CompressionLoss: """ return self._loss - def strip_model(self, model: tf.keras.Model, do_copy: bool = False) -> tf.keras.Model: + def strip_model( + self, model: tf.keras.Model, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> tf.keras.Model: if do_copy: model = copy_model(model) apply_overflow_fix(model, self._op_names) diff --git a/nncf/tensorflow/sparsity/base_algorithm.py b/nncf/tensorflow/sparsity/base_algorithm.py index d18cd44c78c..5abb583129b 100644 --- a/nncf/tensorflow/sparsity/base_algorithm.py +++ b/nncf/tensorflow/sparsity/base_algorithm.py @@ -12,6 +12,7 @@ from nncf.common.compression import BaseCompressionAlgorithmController from nncf.common.sparsity.controller import SparsityController +from nncf.parameters import StripFormat from nncf.tensorflow.graph.metatypes import keras_layers as layer_metatypes from nncf.tensorflow.sparsity.utils import strip_model_from_masks @@ -47,6 +48,8 @@ def __init__(self, target_model, op_names): super().__init__(target_model) self._op_names = op_names - def strip_model(self, model: tf.keras.Model, do_copy: bool = False) -> tf.keras.Model: + def strip_model( + self, model: tf.keras.Model, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> tf.keras.Model: # Transform model for sparsity creates copy of the model. return strip_model_from_masks(model, self._op_names) diff --git a/nncf/tensorflow/strip.py b/nncf/tensorflow/strip.py index 2159ba04cc2..dc9acf0f9a3 100644 --- a/nncf/tensorflow/strip.py +++ b/nncf/tensorflow/strip.py @@ -13,7 +13,9 @@ import tensorflow as tf +import nncf from nncf.common.utils.backend import copy_model +from nncf.parameters import StripFormat from nncf.tensorflow.graph.model_transformer import TFModelTransformer from nncf.tensorflow.graph.transformations.commands import TFOperationWithWeights from nncf.tensorflow.graph.transformations.commands import TFRemovalCommand @@ -28,15 +30,21 @@ from nncf.tensorflow.sparsity.utils import apply_mask -def strip(model: tf.keras.Model, do_copy: bool = True) -> tf.keras.Model: +def strip( + model: tf.keras.Model, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE +) -> tf.keras.Model: """ Implementation of the nncf.strip() function for the TF backend :param model: The compressed model. :param do_copy: If True (default), will return a copy of the currently associated model object. If False, will return the currently associated model object "stripped" in-place. + :param strip format: Describes the format in which model is saved after strip. :return: The stripped model. """ + if strip_format != StripFormat.NATIVE: + msg = f"Tensorflow does not support for {strip_format} strip format." + raise nncf.UnsupportedBackendError(msg) if not isinstance(model, tf.keras.Model): return model diff --git a/nncf/torch/algo_selector.py b/nncf/torch/algo_selector.py index 6b6d4214091..f3aec671693 100644 --- a/nncf/torch/algo_selector.py +++ b/nncf/torch/algo_selector.py @@ -19,6 +19,7 @@ from nncf.common.statistics import NNCFStatistics from nncf.common.utils.backend import copy_model from nncf.common.utils.registry import Registry +from nncf.parameters import StripFormat from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.compression_method_api import PTCompressionLoss @@ -81,7 +82,7 @@ def scheduler(self) -> CompressionScheduler: def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics: return NNCFStatistics() - def strip(self, do_copy: bool = True) -> NNCFNetwork: + def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> NNCFNetwork: model = self.model if do_copy: model = copy_model(self.model) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index c11a80d7955..ee2f42f7c20 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -39,6 +39,7 @@ from nncf.common.insertion_point_graph import PostHookInsertionPoint from nncf.common.insertion_point_graph import PreHookInsertionPoint from nncf.common.utils.debug import is_debug +from nncf.parameters import StripFormat from nncf.telemetry import tracked_function from nncf.telemetry.events import NNCF_PT_CATEGORY from nncf.telemetry.extractors import FunctionCallTelemetryExtractor @@ -966,12 +967,14 @@ def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]: def set_compression_controller(self, ctrl: CompressionAlgorithmController): self.compression_controller = ctrl - def strip(self, do_copy: bool = True) -> "NNCFNetwork": + def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> "NNCFNetwork": """ - Returns the model object with as much custom NNCF additions as possible removed - while still preserving the functioning of the model object as a compressed model. + Removes auxiliary layers and operations added during the compression process, resulting in a clean + model ready for deployment. The functionality of the model object is still preserved as a compressed model. + :param do_copy: If True (default), will return a copy of the currently associated model object. If False, will return the currently associated model object "stripped" in-place. + :param strip format: Describes the format in which model is saved after strip. :return: The stripped model. """ if self.compression_controller is None: @@ -979,8 +982,8 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork": from nncf.torch.quantization.strip import strip_quantized_model model = deepcopy(self._model_ref) if do_copy else self._model_ref - return strip_quantized_model(model) - return self.compression_controller.strip(do_copy) + return strip_quantized_model(model, strip_format=strip_format) + return self.compression_controller.strip(do_copy, strip_format=strip_format) def get_reused_parameters(self): """ diff --git a/nncf/torch/pruning/filter_pruning/algo.py b/nncf/torch/pruning/filter_pruning/algo.py index 82a80d7276b..9977aea8468 100644 --- a/nncf/torch/pruning/filter_pruning/algo.py +++ b/nncf/torch/pruning/filter_pruning/algo.py @@ -45,6 +45,7 @@ from nncf.common.utils.debug import is_debug from nncf.common.utils.os import safe_open from nncf.config.extractors import extract_bn_adaptation_init_params +from nncf.parameters import StripFormat from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.graph.operator_metatypes import PTModuleConv1dMetatype @@ -693,7 +694,9 @@ def _run_batchnorm_adaptation(self): ) self._bn_adaptation.run(self.model) - def strip_model(self, model: NNCFNetwork, do_copy: bool = False) -> NNCFNetwork: + def strip_model( + self, model: NNCFNetwork, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> NNCFNetwork: if do_copy: model = copy_model(model) diff --git a/nncf/torch/quantization/algo.py b/nncf/torch/quantization/algo.py index 2707e454d80..28bcb3ccaff 100644 --- a/nncf/torch/quantization/algo.py +++ b/nncf/torch/quantization/algo.py @@ -77,6 +77,7 @@ from nncf.config.schemata.defaults import QUANTIZE_OUTPUTS from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic +from nncf.parameters import StripFormat from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS from nncf.torch.algo_selector import ZeroCompressionLoss from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder @@ -1478,10 +1479,12 @@ def statistics(self, quickly_collected_only=False) -> NNCFStatistics: nncf_stats.register("quantization", stats) return nncf_stats - def strip_model(self, model: NNCFNetwork, do_copy: bool = False) -> NNCFNetwork: + def strip_model( + self, model: NNCFNetwork, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> NNCFNetwork: if do_copy: model = copy_model(model) - model = strip_quantized_model(model) + model = strip_quantized_model(model, strip_format) return model diff --git a/nncf/torch/quantization/strip.py b/nncf/torch/quantization/strip.py index 0a60a5f2e77..103d06db292 100644 --- a/nncf/torch/quantization/strip.py +++ b/nncf/torch/quantization/strip.py @@ -20,8 +20,7 @@ from nncf.common.graph.transformations.commands import Command from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled -from nncf.experimental.torch2.commands import PT2InsertionCommand +from nncf.parameters import StripFormat from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand @@ -31,14 +30,12 @@ from nncf.torch.model_graph_manager import split_const_name from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.quantization.layers import AsymmetricLoraQuantizer from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import BaseQuantizer from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor -from nncf.torch.quantization.layers import SymmetricLoraQuantizer from nncf.torch.quantization.layers import SymmetricQuantizer from nncf.torch.quantization.quantize_functions import TuneRange @@ -184,21 +181,26 @@ def remove_disabled_quantizers(model: NNCFNetwork) -> NNCFNetwork: return model -def strip_quantized_model(model: NNCFNetwork): +def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripFormat.NATIVE): """ - Returns the model with as much custom NNCF additions as possible removed - while still preserving the functioning of the model object as a compressed model. + Removes auxiliary layers and operations added during the quantization process, + resulting in a clean quantized model ready for deployment. The functionality of the model object is still preserved + as a compressed model. :param model: Compressed model. + :param strip format: Describes the format in which model is saved after strip. :return: The modified NNCF network. """ model_layout = model.nncf.transformation_layout() transformations = model_layout.transformations - if any([type(q.fn) in [AsymmetricLoraQuantizer, SymmetricLoraQuantizer] for q in transformations]): + if strip_format == StripFormat.DQ: model = replace_with_decompressors(model, transformations) - else: + elif strip_format == StripFormat.NATIVE: model = replace_quantizer_to_torch_native_module(model) model = remove_disabled_quantizers(model) + else: + msg = f"Unsupported strip format: {strip_format}" + raise nncf.ParameterNotSupportedError(msg) return model @@ -226,8 +228,16 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command for command in transformations: quantizer = command.fn - if len(command.target_points) > 1: + msg = None + if not isinstance(quantizer, (SymmetricQuantizer, AsymmetricQuantizer)): + msg = f"Unexpected compression module on strip: {quantizer.__class__}" + elif quantizer._qspec.half_range or quantizer._qspec.narrow_range: + msg = "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False." + elif quantizer.num_bits not in [4, 8]: + msg = f"Unsupported number of bits {quantizer.num_bits} for the quantizer {quantizer}" + elif len(command.target_points) > 1: msg = "Command contains more than one target point!" + if msg: raise nncf.ValidationError(msg) tp = command.target_points[0] @@ -244,7 +254,7 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command qdq_weight = quantizer.quantize(original_weight) if hasattr(quantizer, "_lspec"): - # Special reshape for LoRA-grouped output + # Reshape for group-wise quantization, implemented for classes with lora spec only qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape) qdq_weight = qdq_weight.to(original_dtype) @@ -283,8 +293,7 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command result_shape=original_shape, result_dtype=original_dtype, ) - - elif isinstance(quantizer, SymmetricQuantizer): + else: integer_dtype = torch.int8 scale = quantizer.scale / abs(quantizer.level_low) @@ -320,25 +329,13 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command if id(param) == id(original_weight): setattr(consumer_module, name, compressed_parameter) - if is_experimental_torch_tracing_enabled(): - transformation_layout.register( - PT2InsertionCommand( - [ - PTTargetPoint( - TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":") - ) - ], - decompressor, - ) - ) - else: - decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" - transformation_layout.register( - PTSharedFnInsertionCommand( - [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)], - decompressor, - decompressor_name, - ) + decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" + transformation_layout.register( + PTSharedFnInsertionCommand( + [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)], + decompressor, + decompressor_name, ) + ) return PTModelTransformer(model).transform(transformation_layout) diff --git a/nncf/torch/sparsity/base_algo.py b/nncf/torch/sparsity/base_algo.py index 3d6eed86fdf..641c78b7420 100644 --- a/nncf/torch/sparsity/base_algo.py +++ b/nncf/torch/sparsity/base_algo.py @@ -28,6 +28,7 @@ from nncf.common.sparsity.schedulers import SparsityScheduler from nncf.common.utils.api_marker import api from nncf.common.utils.backend import copy_model +from nncf.parameters import StripFormat from nncf.torch.algo_selector import ZeroCompressionLoss from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder from nncf.torch.compression_method_api import PTCompressionAlgorithmController @@ -128,7 +129,9 @@ def disable_scheduler(self): def compression_stage(self) -> CompressionStage: return CompressionStage.FULLY_COMPRESSED - def strip_model(self, model: NNCFNetwork, do_copy: bool = False) -> NNCFNetwork: + def strip_model( + self, model: NNCFNetwork, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE + ) -> NNCFNetwork: if do_copy: model = copy_model(model) diff --git a/nncf/torch/strip.py b/nncf/torch/strip.py index c27ac16fd91..9cf3b484c18 100644 --- a/nncf/torch/strip.py +++ b/nncf/torch/strip.py @@ -10,16 +10,18 @@ # limitations under the License. +from nncf.parameters import StripFormat from nncf.torch.nncf_network import NNCFNetwork -def strip(model: NNCFNetwork, do_copy: bool = True) -> NNCFNetwork: +def strip(model: NNCFNetwork, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> NNCFNetwork: """ - Returns the model object with as much custom NNCF additions as possible removed - while still preserving the functioning of the model object as a compressed model. + Removes auxiliary layers and operations added during the compression process, resulting in a clean + model ready for deployment. The functionality of the model object is still preserved as a compressed model. :param do_copy: If True (default), will return a copy of the currently associated model object. If False, - will return the currently associated model object "stripped" in-place. + will return the currently associated model object "stripped" in-place. + :param strip format: Describes the format in which model is saved after strip. :return: The stripped model. """ - return model.nncf.strip(do_copy) + return model.nncf.strip(do_copy, strip_format) diff --git a/tests/torch/ptq/test_fq_lora.py b/tests/torch/ptq/test_fq_lora.py index fc44b29902d..4af2705b3ae 100644 --- a/tests/torch/ptq/test_fq_lora.py +++ b/tests/torch/ptq/test_fq_lora.py @@ -19,6 +19,7 @@ from transformers import AutoTokenizer import nncf +from nncf.parameters import StripFormat from nncf.torch.quantization.layers import AsymmetricQuantizer as AQ from nncf.torch.quantization.layers import LoraMixin from nncf.torch.quantization.layers import SymmetricQuantizer as SQ @@ -123,20 +124,21 @@ def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num assert first_loss > 8 assert float(loss) < 1 - tuned_output = generate_control_output(model, tokenizer) + with torch.no_grad(): + tuned_output = generate_control_output(model, tokenizer) - # Workaround till export from the optimum would be fixed - CVS-164159 - model = model.to(torch.float32) + # Workaround till export from the optimum would be fixed - CVS-164159 + model = model.to(torch.float32) - model = nncf.strip(model) - stripped_output = generate_control_output(model, tokenizer) + model = nncf.strip(model, strip_format=StripFormat.DQ) + stripped_output = generate_control_output(model, tokenizer) - model = get_ov_model(model, tmp_path) - stripped_ov_output = generate_control_output(model, tokenizer) + model = get_ov_model(model, tmp_path) + stripped_ov_output = generate_control_output(model, tokenizer) - vm = ValidationMock() - tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output) - tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output) + vm = ValidationMock() + tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output) + tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output) - assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=0.01) - assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=0.01) + assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=0.01) + assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=0.01) diff --git a/tests/torch/quantization/test_strip.py b/tests/torch/quantization/test_strip.py index 89e97fd3494..98005664d20 100644 --- a/tests/torch/quantization/test_strip.py +++ b/tests/torch/quantization/test_strip.py @@ -22,6 +22,8 @@ from nncf.common.quantization.quantizers import get_num_levels from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.config import NNCFConfig +from nncf.parameters import CompressWeightsMode +from nncf.parameters import StripFormat from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import PTQuantizerSpec @@ -331,36 +333,36 @@ def test_nncf_strip_api(strip_type, do_copy): @pytest.mark.parametrize( ("mode", "torch_dtype", "atol"), ( - (nncf.CompressWeightsMode.INT4_ASYM, torch.float32, 0.0005), - (nncf.CompressWeightsMode.INT4_ASYM, torch.float16, 0.0005), - (nncf.CompressWeightsMode.INT4_ASYM, torch.bfloat16, 0.01), - (nncf.CompressWeightsMode.INT4_SYM, torch.float32, 0.0005), - (nncf.CompressWeightsMode.INT4_SYM, torch.float16, 0.0005), - (nncf.CompressWeightsMode.INT4_SYM, torch.bfloat16, 0.01), + (CompressWeightsMode.INT4_ASYM, torch.float32, 5e-4), + (CompressWeightsMode.INT4_ASYM, torch.float16, 5e-4), + (CompressWeightsMode.INT4_ASYM, torch.bfloat16, 1e-2), + (CompressWeightsMode.INT4_SYM, torch.float32, 5e-4), + (CompressWeightsMode.INT4_SYM, torch.float16, 5e-4), + (CompressWeightsMode.INT4_SYM, torch.bfloat16, 1e-2), + (CompressWeightsMode.INT8_SYM, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise + (CompressWeightsMode.INT8_ASYM, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise ), ) def test_nncf_strip_lora_model(mode, torch_dtype, atol): input_shape = [1, 16] model = LinearModel(input_shape=input_shape) model = model.to(torch_dtype) - with torch.no_grad(): - example = torch.ones(input_shape).to(torch_dtype) - dataset = [example] - - compressed_model = nncf.compress_weights( - model, - ratio=1, - group_size=4, - mode=mode, - backup_mode=None, - dataset=nncf.Dataset(dataset), - all_layers=True, - compression_format=nncf.CompressionFormat.FQ_LORA, - ) + example = torch.ones(input_shape).to(torch_dtype) + dataset = [example] + + compression_kwargs = dict( + mode=mode, + dataset=nncf.Dataset(dataset), + compression_format=nncf.CompressionFormat.FQ_LORA, + ) + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + compression_kwargs.update(dict(ratio=1, group_size=4, all_layers=True)) + compressed_model = nncf.compress_weights(model, **compression_kwargs) + with torch.no_grad(): compressed_output = compressed_model(example) - strip_compressed_model = nncf.strip(compressed_model, do_copy=True) + strip_compressed_model = nncf.strip(compressed_model, do_copy=True, strip_format=StripFormat.DQ) stripped_output = strip_compressed_model(example) assert torch.allclose(compressed_output, stripped_output, atol=atol) From 00d1ced8bad506a065fec332c29e9bebb5248159 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Mon, 17 Mar 2025 21:49:48 +0100 Subject: [PATCH 3/8] increased thresholds due to torch.compile execution --- tests/torch/ptq/test_fq_lora.py | 8 ++++++-- tests/torch/quantization/test_strip.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/torch/ptq/test_fq_lora.py b/tests/torch/ptq/test_fq_lora.py index 4af2705b3ae..e536adcfd08 100644 --- a/tests/torch/ptq/test_fq_lora.py +++ b/tests/torch/ptq/test_fq_lora.py @@ -124,6 +124,9 @@ def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num assert first_loss > 8 assert float(loss) < 1 + if "awq" in compression_kwargs: + return # Skip test for strip for awq + se initialization. Cases with data-free methods are enough. + with torch.no_grad(): tuned_output = generate_control_output(model, tokenizer) @@ -140,5 +143,6 @@ def test_fq_lora_tuning(tmp_path, mode, backup_mode, compression_kwargs, ref_num tuned_vs_stripped = vm.calculate_similarity(tuned_output, stripped_output) tuned_vs_stripped_ov = vm.calculate_similarity(tuned_output, stripped_ov_output) - assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=0.01) - assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=0.01) + atol = 0.03 if mode == nncf.CompressWeightsMode.INT4_SYM else 0.01 # torch.compile introduces bigger diff + assert torch.allclose(tuned_vs_stripped, vm.validation_ref, atol=atol) + assert torch.allclose(tuned_vs_stripped_ov, vm.validation_ref, atol=atol) diff --git a/tests/torch/quantization/test_strip.py b/tests/torch/quantization/test_strip.py index 98005664d20..1218d421246 100644 --- a/tests/torch/quantization/test_strip.py +++ b/tests/torch/quantization/test_strip.py @@ -337,7 +337,7 @@ def test_nncf_strip_api(strip_type, do_copy): (CompressWeightsMode.INT4_ASYM, torch.float16, 5e-4), (CompressWeightsMode.INT4_ASYM, torch.bfloat16, 1e-2), (CompressWeightsMode.INT4_SYM, torch.float32, 5e-4), - (CompressWeightsMode.INT4_SYM, torch.float16, 5e-4), + (CompressWeightsMode.INT4_SYM, torch.float16, 1e-3), # torch.compile introduces bigger diff for sym (CompressWeightsMode.INT4_SYM, torch.bfloat16, 1e-2), (CompressWeightsMode.INT8_SYM, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise (CompressWeightsMode.INT8_ASYM, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise From 802376f1a91e306147ffb8875e4e85781b62c0e5 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Tue, 18 Mar 2025 17:21:45 +0100 Subject: [PATCH 4/8] minor correction --- nncf/parameters.py | 2 +- nncf/torch/quantization/strip.py | 27 +++++++++++---------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/nncf/parameters.py b/nncf/parameters.py index 8ba084939cd..92b158fa9a6 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -126,7 +126,7 @@ class StripFormat(StrEnum): operations added during the compression process, resulting in a clean model ready for deployment. The functionality of the model object is still preserved as a compressed model. - :param NATIVE: Returns the model with as much custom NNCF additions as possible, + :param NATIVE: Returns the model with as much custom NNCF additions as possible. :param DQ: Replaces FakeQuantize operations with dequantization subgraph and compressed weights in low-bit precision using fake quantize parameters. This is the default format for deployment of models with compressed weights. diff --git a/nncf/torch/quantization/strip.py b/nncf/torch/quantization/strip.py index 103d06db292..989d7d64f10 100644 --- a/nncf/torch/quantization/strip.py +++ b/nncf/torch/quantization/strip.py @@ -10,14 +10,11 @@ # limitations under the License. -from typing import List - import numpy as np import torch from torch.quantization.fake_quantize import FakeQuantize import nncf -from nncf.common.graph.transformations.commands import Command from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.parameters import StripFormat @@ -191,10 +188,8 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF :param strip format: Describes the format in which model is saved after strip. :return: The modified NNCF network. """ - model_layout = model.nncf.transformation_layout() - transformations = model_layout.transformations if strip_format == StripFormat.DQ: - model = replace_with_decompressors(model, transformations) + model = replace_with_decompressors(model) elif strip_format == StripFormat.NATIVE: model = replace_quantizer_to_torch_native_module(model) model = remove_disabled_quantizers(model) @@ -204,7 +199,7 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF return model -def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork: +def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork: """ Performs transformation from fake quantize format (FQ) to dequantization one (DQ). The former takes floating-point input, quantizes and dequantizes, and returns a floating-point value, @@ -222,21 +217,21 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command :return: The modified NNCF network. """ transformation_layout = TransformationLayout() + transformations = model.nncf.transformation_layout().transformations model = model.nncf.get_clean_shallow_copy() graph = model.nncf.get_graph() - for command in transformations: quantizer = command.fn - msg = None + msg = "" if not isinstance(quantizer, (SymmetricQuantizer, AsymmetricQuantizer)): - msg = f"Unexpected compression module on strip: {quantizer.__class__}" - elif quantizer._qspec.half_range or quantizer._qspec.narrow_range: - msg = "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False." - elif quantizer.num_bits not in [4, 8]: - msg = f"Unsupported number of bits {quantizer.num_bits} for the quantizer {quantizer}" - elif len(command.target_points) > 1: - msg = "Command contains more than one target point!" + msg = f"Unexpected compression module on strip: {quantizer.__class__}.\n" + if quantizer._qspec.half_range or quantizer._qspec.narrow_range: + msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n" + if quantizer.num_bits not in [4, 8]: + msg += f"Unsupported number of bits {quantizer.num_bits} for the quantizer {quantizer}.\n" + if len(command.target_points) > 1: + msg += "Command contains more than one target point." if msg: raise nncf.ValidationError(msg) From dcd7c4c543a6364f77b97bc5ed5dc23e9911903b Mon Sep 17 00:00:00 2001 From: Nikolay Date: Wed, 19 Mar 2025 19:44:53 +0100 Subject: [PATCH 5/8] refactored strip --- nncf/torch/quantization/strip.py | 203 ++++++++++++++++++------------- 1 file changed, 120 insertions(+), 83 deletions(-) diff --git a/nncf/torch/quantization/strip.py b/nncf/torch/quantization/strip.py index 989d7d64f10..2bcdca104c3 100644 --- a/nncf/torch/quantization/strip.py +++ b/nncf/torch/quantization/strip.py @@ -10,6 +10,8 @@ # limitations under the License. +from typing import Tuple + import numpy as np import torch from torch.quantization.fake_quantize import FakeQuantize @@ -18,10 +20,10 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.parameters import StripFormat -from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_graph_manager import get_const_data from nncf.torch.model_graph_manager import get_const_node from nncf.torch.model_graph_manager import get_module_by_name from nncf.torch.model_graph_manager import split_const_name @@ -29,6 +31,7 @@ from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import BaseWeightsDecompressor from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor @@ -199,6 +202,104 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF return model +def asym_fq_to_decompressor( + quantizer: AsymmetricQuantizer, weight: torch.Tensor +) -> Tuple[BaseWeightsDecompressor, torch.Tensor]: + """ + Converts an asymmetric quantizer and original weight tensor to a decompressor and quantized weight tensor. + + :param quantizer: The asymmetric quantizer instance. + :param weight: The weight tensor to be compressed and used in decompressor. + :return: The decompressor and quantized weight corresponding to the given quantizer and original weight. + """ + assert isinstance(quantizer, AsymmetricQuantizer) + weight_dtype = weight.dtype + weight_shape = weight.shape + eps = torch.finfo(weight_dtype).eps + qdq_weight = quantizer.quantize(weight) + if hasattr(quantizer, "_lspec"): + # Reshape for group-wise quantization, implemented for classes with lora spec only + qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape) + qdq_weight = qdq_weight.to(weight_dtype) + + input_range_safe = abs(quantizer.input_range) + quantizer.eps + input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels) + + integer_dtype = torch.uint8 + + input_low = input_low.to(weight_dtype) + input_range = input_range.to(weight_dtype) + + scale = input_range / quantizer.level_high + scale = torch.where(torch.abs(scale) < eps, eps, scale) + scale = scale.to(weight_dtype) + + zero_point = quantizer.level_low - torch.round(input_low / scale) + zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high) + zero_point = zero_point.to(integer_dtype) + + q_weight = qdq_weight / scale + q_weight = q_weight + zero_point + q_weight = torch.round(q_weight) + q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high) + q_weight = q_weight.to(integer_dtype) + + if quantizer.num_bits == 8: + decompressor = INT8AsymmetricWeightsDecompressor(scale=scale, zero_point=zero_point, result_dtype=weight_dtype) + else: + decompressor = INT4AsymmetricWeightsDecompressor( + scale=scale, + zero_point=zero_point, + compressed_weight_shape=q_weight.shape, + result_shape=weight_shape, + result_dtype=weight_dtype, + ) + return decompressor, q_weight + + +def sym_fq_to_decompressor( + quantizer: SymmetricQuantizer, weight: torch.Tensor +) -> Tuple[BaseWeightsDecompressor, torch.Tensor]: + """ + Converts an asymmetric quantizer and original weight tensor to a decompressor and quantized weight tensor. + + :param quantizer: The asymmetric quantizer instance. + :param weight: The weight tensor to be compressed and used in decompressor. + :return: The decompressor and quantized weight corresponding to the given quantizer and original weight. + """ + assert isinstance(quantizer, SymmetricQuantizer) + weight_dtype = weight.dtype + weight_shape = weight.shape + eps = torch.finfo(weight_dtype).eps + qdq_weight = quantizer.quantize(weight) + if hasattr(quantizer, "_lspec"): + # Reshape for group-wise quantization, implemented for classes with lora spec only + qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape) + qdq_weight = qdq_weight.to(weight_dtype) + + integer_dtype = torch.int8 + + scale = quantizer.scale / abs(quantizer.level_low) + scale = torch.where(torch.abs(scale) < eps, eps, scale) + scale = scale.to(weight_dtype) + + q_weight = qdq_weight / scale + q_weight = torch.round(q_weight) + q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high) + q_weight = q_weight.to(integer_dtype) + + if quantizer.num_bits == 8: + decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=weight_dtype) + else: + decompressor = INT4SymmetricWeightsDecompressor( + scale=scale, + compressed_weight_shape=q_weight.shape, + result_shape=weight_shape, + result_dtype=weight_dtype, + ) + return decompressor, q_weight + + def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork: """ Performs transformation from fake quantize format (FQ) to dequantization one (DQ). @@ -222,10 +323,11 @@ def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork: graph = model.nncf.get_graph() for command in transformations: quantizer = command.fn + if not isinstance(quantizer, (SymmetricQuantizer, AsymmetricQuantizer)): + # strip is only applied to Fake Quantizers, skip all other modules, e.g. SQMultiply for AWQ + continue msg = "" - if not isinstance(quantizer, (SymmetricQuantizer, AsymmetricQuantizer)): - msg = f"Unexpected compression module on strip: {quantizer.__class__}.\n" if quantizer._qspec.half_range or quantizer._qspec.narrow_range: msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n" if quantizer.num_bits not in [4, 8]: @@ -238,91 +340,26 @@ def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork: tp = command.target_points[0] node_with_weight = graph.get_node_by_name(tp.target_node_name) weight_node = get_const_node(node_with_weight, tp.input_port_id, graph) + weight_name = weight_node.layer_attributes.name + weight = get_const_data(weight_node, model) - module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name) - module = get_module_by_name(module_name, model) - original_weight = getattr(module, weight_attr_name) - - original_dtype = original_weight.dtype - original_shape = original_weight.shape - original_eps = torch.finfo(original_dtype).eps - - qdq_weight = quantizer.quantize(original_weight) - if hasattr(quantizer, "_lspec"): - # Reshape for group-wise quantization, implemented for classes with lora spec only - qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape) - qdq_weight = qdq_weight.to(original_dtype) - - if isinstance(quantizer, AsymmetricQuantizer): - input_range_safe = abs(quantizer.input_range) + quantizer.eps - input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels) - - integer_dtype = torch.uint8 - - input_low = input_low.to(original_dtype) - input_range = input_range.to(original_dtype) - - scale = input_range / quantizer.level_high - scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale) - scale = scale.to(original_dtype) - - zero_point = quantizer.level_low - torch.round(input_low / scale) - zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high) - zero_point = zero_point.to(integer_dtype) - - q_weight = qdq_weight / scale - q_weight = q_weight + zero_point - q_weight = torch.round(q_weight) - q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high) - q_weight = q_weight.to(integer_dtype) - - if quantizer.num_bits == 8: - decompressor = INT8AsymmetricWeightsDecompressor( - scale=scale, zero_point=zero_point, result_dtype=original_dtype - ) - else: - decompressor = INT4AsymmetricWeightsDecompressor( - scale=scale, - zero_point=zero_point, - compressed_weight_shape=q_weight.shape, - result_shape=original_shape, - result_dtype=original_dtype, - ) - else: - integer_dtype = torch.int8 - - scale = quantizer.scale / abs(quantizer.level_low) - scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale) - scale = scale.to(original_dtype) - - q_weight = qdq_weight / scale - q_weight = torch.round(q_weight) - q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high) - q_weight = q_weight.to(integer_dtype) - - if quantizer.num_bits == 8: - decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=original_dtype) - else: - decompressor = INT4SymmetricWeightsDecompressor( - scale=scale, - compressed_weight_shape=q_weight.shape, - result_shape=original_shape, - result_dtype=original_dtype, - ) + convert_fn = asym_fq_to_decompressor if isinstance(quantizer, AsymmetricQuantizer) else sym_fq_to_decompressor + decompressor, q_weight = convert_fn(quantizer, weight) packed_tensor = decompressor.pack_weight(q_weight) # sets compressed tensor - compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False) - setattr(module, weight_attr_name, compressed_parameter) - - consumer_nodes = graph.get_next_nodes(weight_node) - if len(consumer_nodes) > 1: - for consumer_node in consumer_nodes: - consumer_module = model.nncf.get_module_by_scope(Scope.from_str(consumer_node.layer_name)) - for name, param in consumer_module.named_parameters(recurse=False, remove_duplicate=False): - if id(param) == id(original_weight): - setattr(consumer_module, name, compressed_parameter) + # TODO:(AlexanderDokuchaev): update set_const_data + module_name, weight_attr_name = split_const_name(weight_name) + module = get_module_by_name(module_name, model) + weight = getattr(module, weight_attr_name) + + if not isinstance(weight, torch.nn.Parameter): + msg = f"Weight is not a torch.nn.Parameter in the model by name {weight_name}." + raise nncf.InternalError(msg) + + weight.requires_grad = False + weight.data = packed_tensor decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" transformation_layout.register( From 76f3c8ea980ee2f68d5b52741b9474cfc0dad485 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Wed, 19 Mar 2025 22:52:09 +0100 Subject: [PATCH 6/8] corrections --- nncf/torch/quantization/strip.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nncf/torch/quantization/strip.py b/nncf/torch/quantization/strip.py index 2bcdca104c3..f8ae9d21f48 100644 --- a/nncf/torch/quantization/strip.py +++ b/nncf/torch/quantization/strip.py @@ -325,6 +325,7 @@ def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork: quantizer = command.fn if not isinstance(quantizer, (SymmetricQuantizer, AsymmetricQuantizer)): # strip is only applied to Fake Quantizers, skip all other modules, e.g. SQMultiply for AWQ + transformation_layout.register(command) continue msg = "" @@ -340,6 +341,9 @@ def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork: tp = command.target_points[0] node_with_weight = graph.get_node_by_name(tp.target_node_name) weight_node = get_const_node(node_with_weight, tp.input_port_id, graph) + if weight_node is None: + msg = "FQ is not assigned to weight. Strip to DQ format is not supported for FQ on activation." + raise nncf.UnsupportedModelError(msg) weight_name = weight_node.layer_attributes.name weight = get_const_data(weight_node, model) From c9eef2e1ad1ed4023bdddb347e18b78a16b28ea1 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Thu, 20 Mar 2025 11:46:58 +0100 Subject: [PATCH 7/8] Extended test for strip --- tests/torch/quantization/test_strip.py | 81 +++++++++++++++++++------- 1 file changed, 61 insertions(+), 20 deletions(-) diff --git a/tests/torch/quantization/test_strip.py b/tests/torch/quantization/test_strip.py index 1218d421246..041cbb0bdfb 100644 --- a/tests/torch/quantization/test_strip.py +++ b/tests/torch/quantization/test_strip.py @@ -9,11 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Any, Tuple import numpy as np import pytest import torch +from torch import nn from torch.quantization.fake_quantize import FakeQuantize import nncf @@ -26,6 +27,11 @@ from nncf.parameters import StripFormat from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor as INT4AsymDQ +from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor as INT4SymDQ +from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor as INT8AsymDQ +from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor as INT8SymDQ from nncf.torch.quantization.layers import PTQuantizerSpec from nncf.torch.quantization.layers import SymmetricQuantizer from nncf.torch.quantization.strip import convert_to_torch_fakequantizer @@ -330,26 +336,46 @@ def test_nncf_strip_api(strip_type, do_copy): assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize) +def check_compression_modules( + model_: nn.Module, + expected_module_type: ExtraCompressionModuleType, + not_expected_module_type: ExtraCompressionModuleType, + expected_class: Any, +) -> None: + """ + Checks if the given model has the expected compression module registered and not the unexpected one. + Also verifies that the compression module is of the expected class type. + + :param model_: The model to be checked, which should have an 'nncf' attribute with compression module methods. + :param expected_module_type: The type of the compression module that is expected to be registered. + :param not_expected_module_type: The type of the compression module that is not expected to be registered. + :param expected_class: The class type that the expected compression module should be an instance of. + """ + assert model_.nncf.is_compression_module_registered(expected_module_type) + assert not model_.nncf.is_compression_module_registered(not_expected_module_type) + compression_modules_dict = model_.nncf.get_compression_modules_by_type(expected_module_type) + assert len(compression_modules_dict) == 1 + compression_module = next(iter(compression_modules_dict.values())) + assert isinstance(compression_module, expected_class) + + @pytest.mark.parametrize( - ("mode", "torch_dtype", "atol"), + ("mode", "decompressor_class", "torch_dtype", "atol"), ( - (CompressWeightsMode.INT4_ASYM, torch.float32, 5e-4), - (CompressWeightsMode.INT4_ASYM, torch.float16, 5e-4), - (CompressWeightsMode.INT4_ASYM, torch.bfloat16, 1e-2), - (CompressWeightsMode.INT4_SYM, torch.float32, 5e-4), - (CompressWeightsMode.INT4_SYM, torch.float16, 1e-3), # torch.compile introduces bigger diff for sym - (CompressWeightsMode.INT4_SYM, torch.bfloat16, 1e-2), - (CompressWeightsMode.INT8_SYM, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise - (CompressWeightsMode.INT8_ASYM, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise + (CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float32, 5e-4), + (CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.float16, 5e-4), + (CompressWeightsMode.INT4_ASYM, INT4AsymDQ, torch.bfloat16, 1e-2), + (CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float32, 5e-4), + (CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.float16, 1e-3), # torch.compile introduces bigger diff for sym + (CompressWeightsMode.INT4_SYM, INT4SymDQ, torch.bfloat16, 1e-2), + (CompressWeightsMode.INT8_SYM, INT8SymDQ, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise + (CompressWeightsMode.INT8_ASYM, INT8AsymDQ, torch.bfloat16, 5e-2), # int8 uses per-channel vs int4 group-wise ), ) -def test_nncf_strip_lora_model(mode, torch_dtype, atol): +def test_nncf_strip_lora_model(mode, decompressor_class, torch_dtype, atol, mocker): input_shape = [1, 16] - model = LinearModel(input_shape=input_shape) - model = model.to(torch_dtype) - example = torch.ones(input_shape).to(torch_dtype) - dataset = [example] - + model = LinearModel(input_shape=input_shape).to(torch_dtype) + dataset = [torch.ones(input_shape).to(torch_dtype)] compression_kwargs = dict( mode=mode, dataset=nncf.Dataset(dataset), @@ -358,11 +384,26 @@ def test_nncf_strip_lora_model(mode, torch_dtype, atol): if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: compression_kwargs.update(dict(ratio=1, group_size=4, all_layers=True)) compressed_model = nncf.compress_weights(model, **compression_kwargs) + check_compression_modules( + compressed_model, + expected_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + not_expected_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + expected_class=BaseQuantizer, + ) + assert compressed_model.linear.weight.dtype == torch_dtype + pack_weight_spy = mocker.spy(decompressor_class, "pack_weight") with torch.no_grad(): - compressed_output = compressed_model(example) - + compressed_output = compressed_model(dataset[0]) strip_compressed_model = nncf.strip(compressed_model, do_copy=True, strip_format=StripFormat.DQ) - stripped_output = strip_compressed_model(example) - + stripped_output = strip_compressed_model(dataset[0]) + + assert pack_weight_spy.call_count in [1, 2] # pack_weight for asym is called twice: for ZP and weight + assert strip_compressed_model.linear.weight.dtype in [torch.uint8, torch.int8] + check_compression_modules( + strip_compressed_model, + expected_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + not_expected_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + expected_class=decompressor_class, + ) assert torch.allclose(compressed_output, stripped_output, atol=atol) From ec4cbc98160576b544bf1d2e8dc5a139d4697f0a Mon Sep 17 00:00:00 2001 From: Nikolay Date: Thu, 20 Mar 2025 13:56:32 +0100 Subject: [PATCH 8/8] fixed tests --- tests/torch/quantization/test_strip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/torch/quantization/test_strip.py b/tests/torch/quantization/test_strip.py index 041cbb0bdfb..a0aa8491a01 100644 --- a/tests/torch/quantization/test_strip.py +++ b/tests/torch/quantization/test_strip.py @@ -25,6 +25,7 @@ from nncf.config import NNCFConfig from nncf.parameters import CompressWeightsMode from nncf.parameters import StripFormat +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.quantization.layers import AsymmetricQuantizer from nncf.torch.quantization.layers import BaseQuantizer @@ -380,6 +381,7 @@ def test_nncf_strip_lora_model(mode, decompressor_class, torch_dtype, atol, mock mode=mode, dataset=nncf.Dataset(dataset), compression_format=nncf.CompressionFormat.FQ_LORA, + advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=1), ) if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: compression_kwargs.update(dict(ratio=1, group_size=4, all_layers=True))