Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QAT Lora 4/N] Strip for LoRA modules #3348

Merged
merged 10 commits into from
Mar 20, 2025
1 change: 1 addition & 0 deletions nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from nncf.parameters import ModelType as ModelType
from nncf.parameters import QuantizationMode as QuantizationMode
from nncf.parameters import SensitivityMetric as SensitivityMetric
from nncf.parameters import StripFormat as StripFormat
from nncf.parameters import TargetDevice as TargetDevice
from nncf.quantization import QuantizationPreset as QuantizationPreset
from nncf.quantization import compress_weights as compress_weights
Expand Down
15 changes: 10 additions & 5 deletions nncf/api/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.api_marker import api
from nncf.common.utils.backend import copy_model
from nncf.parameters import StripFormat

TModel = TypeVar("TModel")

Expand Down Expand Up @@ -236,14 +237,17 @@ def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
need to keep track of statistics on each training batch/step/iteration.
"""

def strip_model(self, model: TModel, do_copy: bool = False) -> TModel:
def strip_model(
self, model: TModel, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE
) -> TModel:
"""
Strips auxiliary layers that were used for the model compression, as it's
only needed for training. The method is used before exporting the model
in the target format.

:param model: The compressed model.
:param do_copy: Modify copy of the model, defaults to False.
:param strip_format: Describes the format in which the model is saved after strip.
:return: The stripped model.
"""
if do_copy:
Expand All @@ -256,16 +260,17 @@ def prepare_for_export(self) -> None:
"""
self._model = self.strip_model(self._model)

def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:  # type: ignore[type-var]
    """
    Returns the model object with as much custom NNCF additions as possible removed
    while still preserving the functioning of the model object as a compressed model.
    Removes auxiliary layers and operations added during the compression process, resulting in a clean
    model ready for deployment. The functionality of the model object is still preserved as a compressed model.

    :param do_copy: If True (default), will return a copy of the currently associated model object. If False,
        will return the currently associated model object "stripped" in-place.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The stripped model.
    """
    return self.strip_model(self.model, do_copy, strip_format)  # type: ignore

@abstractmethod
def export_model(
Expand Down
5 changes: 3 additions & 2 deletions nncf/common/composite_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import copy_model
from nncf.common.utils.backend import get_backend
from nncf.parameters import StripFormat


class CompositeCompressionLoss(CompressionLoss):
Expand Down Expand Up @@ -276,12 +277,12 @@ def prepare_for_export(self) -> None:
stripped_model = ctrl.strip_model(stripped_model)
self._model = stripped_model

def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:  # type: ignore
    """
    Strips auxiliary compression layers by delegating to every child controller in turn.

    :param do_copy: If True (default), strip a copy of the associated model object; otherwise strip in-place.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The stripped model.
    """
    stripped = copy_model(self.model) if do_copy else self.model
    # Each child controller removes its own auxiliary layers; the copy (if any) was made above.
    for child_ctrl in self.child_ctrls:
        stripped = child_ctrl.strip_model(stripped, do_copy=False, strip_format=strip_format)
    return stripped  # type: ignore

@property
Expand Down
14 changes: 8 additions & 6 deletions nncf/common/strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nncf.common.utils.api_marker import api
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.parameters import StripFormat
from nncf.telemetry.decorator import tracked_function
from nncf.telemetry.events import MODEL_BASED_CATEGORY
from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
Expand All @@ -25,25 +26,26 @@

@api(canonical_alias="nncf.strip")
@tracked_function(category=MODEL_BASED_CATEGORY, extractors=[FunctionCallTelemetryExtractor("nncf.strip")])
def strip(model: TModel, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> TModel:
    """
    Returns the model object with as much custom NNCF additions as possible removed
    while still preserving the functioning of the model object as a compressed model.
    Removes auxiliary layers and operations added during the compression process, resulting in a clean
    model ready for deployment. The functionality of the model object is still preserved as a compressed model.

    :param model: The compressed model.
    :param do_copy: If True (default), will return a copy of the currently associated model object. If False,
        will return the currently associated model object "stripped" in-place.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The stripped model.
    :raises nncf.UnsupportedBackendError: If the model backend is neither Torch nor TensorFlow.
    """
    model_backend = get_backend(model)
    if model_backend == BackendType.TORCH:
        # Imported lazily to avoid pulling in torch for non-torch backends.
        from nncf.torch.strip import strip as strip_pt

        return strip_pt(model, do_copy, strip_format)  # type: ignore
    if model_backend == BackendType.TENSORFLOW:
        # Imported lazily to avoid pulling in tensorflow for non-TF backends.
        from nncf.tensorflow.strip import strip as strip_tf

        return strip_tf(model, do_copy, strip_format)  # type: ignore

    msg = f"Method `strip` does not support {model_backend.value} backend."
    raise nncf.UnsupportedBackendError(msg)
5 changes: 4 additions & 1 deletion nncf/experimental/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from nncf.experimental.tensorflow.quantization.init_range import RangeInitializerV2
from nncf.experimental.tensorflow.quantization.init_range import TFRangeInitParamsV2
from nncf.experimental.tensorflow.quantization.quantizers import create_quantizer
from nncf.parameters import StripFormat
from nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS
from nncf.tensorflow.graph.metatypes.tf_ops import TFOpWithWeightsMetatype
from nncf.tensorflow.graph.transformations.commands import TFInsertionCommand
Expand Down Expand Up @@ -353,7 +354,9 @@ def apply_to(self, model: NNCFNetwork) -> NNCFNetwork:


class QuantizationControllerV2(QuantizationController):
def strip_model(
    self, model: NNCFNetwork, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE
) -> NNCFNetwork:
    """
    Returns the model as-is (optionally copied); no stripping transformations are applied here.

    :param model: The compressed model.
    :param do_copy: Modify a copy of the model, defaults to False.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The (optionally copied) model.
    """
    return copy_model(model) if do_copy else model
Expand Down
17 changes: 17 additions & 0 deletions nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,23 @@ class CompressionFormat(StrEnum):
FQ_LORA = "fake_quantize_with_lora"


@api(canonical_alias="nncf.StripFormat")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should appear in the nncf API docs after the merge, please check everything is shown properly after the merge

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, will check

class StripFormat(StrEnum):
"""
Describes the format in which model is saved after strip: operation that removes auxiliary layers and
operations added during the compression process, resulting in a clean model ready for deployment.
The functionality of the model object is still preserved as a compressed model.

:param NATIVE: Returns the model with as much custom NNCF additions as possible,
:param DQ: Replaces FakeQuantize operations with dequantization subgraph and compressed weights in low-bit
precision using fake quantize parameters. This is the default format for deployment of models with compressed
weights.
"""

NATIVE = "native"
DQ = "dequantize"


@api(canonical_alias="nncf.BackupMode")
class BackupMode(StrEnum):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def get_fq_insertion_command(
orig_weight_shape: Tuple[int, ...],
compression_format: CompressionFormat,
lora_adapter_rank: int,
is_all_8bit: bool,
) -> PTTransformationCommand:
"""
Creates a fake quantization insertion command for the given compressed weight.
Expand All @@ -291,9 +292,11 @@ def get_fq_insertion_command(
:param wc_params: Parameters for weight compression.
:param orig_weight_shape: The original shape of the weight tensor.
:param compression_format: The format of compression.
:param is_all_8bit: Flag indicating if all weights should be compressed to 8-bit.
:return: A PTTransformationCommand for inserting fake quantization to the model.
"""
compression_config = wc_params.compression_config
# default mapping for 4bit weight compression and FQ_LORA format, no need to add lora adapters for 8bit weight
mode_vs_schema_map = {
CompressWeightsMode.INT4_ASYM: QuantizationScheme.ASYMMETRIC_LORA,
CompressWeightsMode.INT4_SYM: QuantizationScheme.SYMMETRIC_LORA,
Expand All @@ -303,6 +306,9 @@ def get_fq_insertion_command(
if compression_format == CompressionFormat.FQ:
mode_vs_schema_map[CompressWeightsMode.INT4_ASYM] = QuantizationScheme.ASYMMETRIC
mode_vs_schema_map[CompressWeightsMode.INT4_SYM] = QuantizationScheme.SYMMETRIC
if is_all_8bit and compression_format == CompressionFormat.FQ_LORA:
mode_vs_schema_map[CompressWeightsMode.INT8_ASYM] = QuantizationScheme.ASYMMETRIC_LORA
mode_vs_schema_map[CompressWeightsMode.INT8_SYM] = QuantizationScheme.SYMMETRIC_LORA

schema = mode_vs_schema_map[compression_config.mode]

Expand Down Expand Up @@ -469,6 +475,7 @@ def transform_model(
model_transformer = PTModelTransformer(model)

transformation_layout = TransformationLayout()
is_all_8bit = all(wc_params.compression_config.num_bits == 8 for wc_params in weight_compression_parameters)
for wc_params in weight_compression_parameters:
compression_config = wc_params.compression_config
if compression_config.mode in [
Expand Down Expand Up @@ -499,7 +506,7 @@ def transform_model(
else:
rank = advanced_parameters.lora_adapter_rank
command = self.get_fq_insertion_command(
compressed_weight, wc_params, weight.shape, compression_format, rank
compressed_weight, wc_params, weight.shape, compression_format, rank, is_all_8bit
)
transformation_layout.register(command)

Expand Down
3 changes: 2 additions & 1 deletion nncf/tensorflow/algorithm_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.backend import copy_model
from nncf.common.utils.registry import Registry
from nncf.parameters import StripFormat
from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from nncf.tensorflow.loss import TFZeroCompressionLoss

Expand Down Expand Up @@ -60,7 +61,7 @@ def scheduler(self) -> StubCompressionScheduler:
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
return NNCFStatistics()

def strip(self, do_copy: bool = True) -> tf.keras.Model:
def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> tf.keras.Model:
model = self.model
if do_copy:
model = copy_model(self.model)
Expand Down
5 changes: 4 additions & 1 deletion nncf/tensorflow/pruning/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from nncf.config.schemata.defaults import PRUNE_DOWNSAMPLE_CONVS
from nncf.config.schemata.defaults import PRUNE_FIRST_CONV
from nncf.config.schemata.defaults import PRUNING_INIT
from nncf.parameters import StripFormat
from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from nncf.tensorflow.graph.converter import TFModelConverterFactory
from nncf.tensorflow.graph.metatypes.keras_layers import TFBatchNormalizationLayerMetatype
Expand Down Expand Up @@ -359,6 +360,8 @@ def _calculate_pruned_layers_summary(self) -> List[PrunedLayerSummary]:

return pruned_layers_summary

def strip_model(
    self, model: tf.keras.Model, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE
) -> tf.keras.Model:
    """
    Strips the pruning masks from the model.

    :param model: The compressed model.
    :param do_copy: Ignored: transforming the model for pruning already produces a copy of the model.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The stripped model.
    """
    return strip_model_from_masks(model, self._op_names)
5 changes: 4 additions & 1 deletion nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from nncf.config.schemata.defaults import QUANTIZE_INPUTS
from nncf.config.schemata.defaults import QUANTIZE_OUTPUTS
from nncf.config.schemata.defaults import TARGET_DEVICE
from nncf.parameters import StripFormat
from nncf.tensorflow.algorithm_selector import TF_COMPRESSION_ALGORITHMS
from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from nncf.tensorflow.graph.converter import TFModelConverter
Expand Down Expand Up @@ -753,7 +754,9 @@ def loss(self) -> CompressionLoss:
"""
return self._loss

def strip_model(self, model: tf.keras.Model, do_copy: bool = False) -> tf.keras.Model:
def strip_model(
self, model: tf.keras.Model, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE
) -> tf.keras.Model:
if do_copy:
model = copy_model(model)
apply_overflow_fix(model, self._op_names)
Expand Down
5 changes: 4 additions & 1 deletion nncf/tensorflow/sparsity/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from nncf.common.compression import BaseCompressionAlgorithmController
from nncf.common.sparsity.controller import SparsityController
from nncf.parameters import StripFormat
from nncf.tensorflow.graph.metatypes import keras_layers as layer_metatypes
from nncf.tensorflow.sparsity.utils import strip_model_from_masks

Expand Down Expand Up @@ -47,6 +48,8 @@ def __init__(self, target_model, op_names):
super().__init__(target_model)
self._op_names = op_names

def strip_model(
    self, model: tf.keras.Model, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE
) -> tf.keras.Model:
    """
    Strips the sparsity masks from the model.

    :param model: The compressed model.
    :param do_copy: Ignored: transforming the model for sparsity already produces a copy of the model.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The stripped model.
    """
    return strip_model_from_masks(model, self._op_names)
10 changes: 9 additions & 1 deletion nncf/tensorflow/strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

import tensorflow as tf

import nncf
from nncf.common.utils.backend import copy_model
from nncf.parameters import StripFormat
from nncf.tensorflow.graph.model_transformer import TFModelTransformer
from nncf.tensorflow.graph.transformations.commands import TFOperationWithWeights
from nncf.tensorflow.graph.transformations.commands import TFRemovalCommand
Expand All @@ -28,15 +30,21 @@
from nncf.tensorflow.sparsity.utils import apply_mask


def strip(model: tf.keras.Model, do_copy: bool = True) -> tf.keras.Model:
def strip(
model: tf.keras.Model, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE
) -> tf.keras.Model:
"""
Implementation of the nncf.strip() function for the TF backend

:param model: The compressed model.
:param do_copy: If True (default), will return a copy of the currently associated model object. If False,
will return the currently associated model object "stripped" in-place.
:param strip_format: Describes the format in which the model is saved after strip.
:return: The stripped model.
"""
if strip_format != StripFormat.NATIVE:
    msg = f"Tensorflow does not support {strip_format} strip format."
    raise nncf.UnsupportedBackendError(msg)
if not isinstance(model, tf.keras.Model):
return model

Expand Down
3 changes: 2 additions & 1 deletion nncf/torch/algo_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.backend import copy_model
from nncf.common.utils.registry import Registry
from nncf.parameters import StripFormat
from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.compression_method_api import PTCompressionLoss
Expand Down Expand Up @@ -81,7 +82,7 @@ def scheduler(self) -> CompressionScheduler:
def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
return NNCFStatistics()

def strip(self, do_copy: bool = True) -> NNCFNetwork:
def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> NNCFNetwork:
model = self.model
if do_copy:
model = copy_model(self.model)
Expand Down
13 changes: 8 additions & 5 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from nncf.common.insertion_point_graph import PostHookInsertionPoint
from nncf.common.insertion_point_graph import PreHookInsertionPoint
from nncf.common.utils.debug import is_debug
from nncf.parameters import StripFormat
from nncf.telemetry import tracked_function
from nncf.telemetry.events import NNCF_PT_CATEGORY
from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
Expand Down Expand Up @@ -966,21 +967,23 @@ def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]:
def set_compression_controller(self, ctrl: CompressionAlgorithmController):
self.compression_controller = ctrl

def strip(self, do_copy: bool = True, strip_format: StripFormat = StripFormat.NATIVE) -> "NNCFNetwork":
    """
    Returns the model object with as much custom NNCF additions as possible removed
    while still preserving the functioning of the model object as a compressed model.
    Removes auxiliary layers and operations added during the compression process, resulting in a clean
    model ready for deployment. The functionality of the model object is still preserved as a compressed model.

    :param do_copy: If True (default), will return a copy of the currently associated model object. If False,
        will return the currently associated model object "stripped" in-place.
    :param strip_format: Describes the format in which the model is saved after strip.
    :return: The stripped model.
    """
    if self.compression_controller is None:
        # PTQ algorithm does not set compressed controller
        from nncf.torch.quantization.strip import strip_quantized_model

        model = deepcopy(self._model_ref) if do_copy else self._model_ref
        return strip_quantized_model(model, strip_format=strip_format)
    return self.compression_controller.strip(do_copy, strip_format=strip_format)

def get_reused_parameters(self):
"""
Expand Down
5 changes: 4 additions & 1 deletion nncf/torch/pruning/filter_pruning/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from nncf.common.utils.debug import is_debug
from nncf.common.utils.os import safe_open
from nncf.config.extractors import extract_bn_adaptation_init_params
from nncf.parameters import StripFormat
from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.graph.operator_metatypes import PTModuleConv1dMetatype
Expand Down Expand Up @@ -693,7 +694,9 @@ def _run_batchnorm_adaptation(self):
)
self._bn_adaptation.run(self.model)

def strip_model(self, model: NNCFNetwork, do_copy: bool = False) -> NNCFNetwork:
def strip_model(
self, model: NNCFNetwork, do_copy: bool = False, strip_format: StripFormat = StripFormat.NATIVE
) -> NNCFNetwork:
if do_copy:
model = copy_model(model)

Expand Down
Loading