 # limitations under the License.


+from typing import List
+
 import numpy as np
 import torch
 from torch.quantization.fake_quantize import FakeQuantize

 import nncf
+from nncf.common.graph.transformations.commands import Command
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
+from nncf.experimental.torch2.commands import PT2InsertionCommand
+from nncf.torch.dynamic_graph.scope import Scope
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.model_graph_manager import get_const_node
+from nncf.torch.model_graph_manager import get_module_by_name
+from nncf.torch.model_graph_manager import split_const_name
+from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import AsymmetricQuantizer
 from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import SymmetricLoraQuantizer
 from nncf.torch.quantization.layers import SymmetricQuantizer
+from nncf.torch.quantization.quantize_functions import TuneRange

 SUPPORTED_NUM_BITS_FOR_STRIP_MODEL = [8]

@@ -171,6 +192,153 @@ def strip_quantized_model(model: NNCFNetwork):
     :param model: Compressed model.
     :return: The modified NNCF network.
     """
-    model = replace_quantizer_to_torch_native_module(model)
-    model = remove_disabled_quantizers(model)
+    model_layout = model.nncf.transformation_layout()
+    transformations = model_layout.transformations
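+    # Models with LoRA quantizers are stripped to integer weights plus decompressor
+    # modules; everything else is converted to torch-native FakeQuantize modules.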
+    if any(isinstance(q.fn, (AsymmetricLoraQuantizer, SymmetricLoraQuantizer)) for q in transformations):
+        model = replace_with_decompressors(model, transformations)
+    else:
+        model = replace_quantizer_to_torch_native_module(model)
+        model = remove_disabled_quantizers(model)
     return model
+
+
+def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork:
+    """
+    Performs the transformation from the fake quantize (FQ) format to the dequantize (DQ) format.
+    The former takes a floating-point input, quantizes and dequantizes it, and returns a floating-point value,
+    while the latter takes a quantized integer representation, dequantizes it, and outputs a floating-point result.
+
+    Mathematically, both methods lead to the same outcome, but due to differences in the order of operations and
+    rounding errors, the actual results may differ. In particular, this error can occur for values
+    that lie exactly at the midpoint between two quantized values ("quants").
+
+    The FQ format may round such a value to one "quant", while the DQ format may round it to another.
+    To avoid this mismatch, the compressed representation is obtained not by quantizing the original weight
+    directly, but by quantizing its pre-processed, fake-quantized, floating-point representation.
+
+    :param model: Compressed model.
+    :param transformations: List of insertion commands holding the quantizers to replace.
+    :return: The modified NNCF network with integer weights and decompressors.
+    """
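+    # Illustrative example (not from the original code): with scale = 1.0 and zero point = 0,
+    # a weight of exactly 0.5 may be rounded one way by FQ and the other way when re-quantized
+    # in DQ; quantizing the already fake-quantized weight below makes both paths agree.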
+    transformation_layout = TransformationLayout()
+    model = model.nncf.get_clean_shallow_copy()
+    graph = model.nncf.get_graph()
+
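+    # Each command wraps a single quantizer module (command.fn) attached to one weight.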
+    for command in transformations:
+        quantizer = command.fn
+
+        if len(command.target_points) > 1:
+            msg = "Command contains more than one target point!"
+            raise nncf.ValidationError(msg)
+
+        tp = command.target_points[0]
+        node_with_weight = graph.get_node_by_name(tp.target_node_name)
+        weight_node = get_const_node(node_with_weight, tp.input_port_id, graph)
+
+        module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
+        module = get_module_by_name(module_name, model)
+        original_weight = getattr(module, weight_attr_name)
+
+        original_dtype = original_weight.dtype
+        original_shape = original_weight.shape
+        original_eps = torch.finfo(original_dtype).eps
+
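+        # Fake-quantize the weight first; the integer representation below is derived
+        # from this value so that the DQ path reproduces the FQ rounding (see docstring).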
+        qdq_weight = quantizer.quantize(original_weight)
+        if hasattr(quantizer, "_lspec"):
+            # Special reshape for LoRA-grouped output
+            qdq_weight = qdq_weight.reshape(quantizer._lspec.weight_shape)
+        qdq_weight = qdq_weight.to(original_dtype)
+
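+        # Derive the integer-domain parameters (scale, and a zero point in the
+        # asymmetric case) from the quantizer's floating-point FQ parameters.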
+        if isinstance(quantizer, AsymmetricQuantizer):
+            input_range_safe = abs(quantizer.input_range) + quantizer.eps
+            input_low, input_range = TuneRange.apply(quantizer.input_low, input_range_safe, quantizer.levels)
+
+            integer_dtype = torch.uint8
+
+            input_low = input_low.to(original_dtype)
+            input_range = input_range.to(original_dtype)
+
+            scale = input_range / quantizer.level_high
+            scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
+            scale = scale.to(original_dtype)
+
+            zero_point = quantizer.level_low - torch.round(input_low / scale)
+            zero_point = torch.clip(zero_point, quantizer.level_low, quantizer.level_high)
+            zero_point = zero_point.to(integer_dtype)
+
+            q_weight = qdq_weight / scale
+            q_weight = q_weight + zero_point
+            q_weight = torch.round(q_weight)
+            q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
+            q_weight = q_weight.to(integer_dtype)
+
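+            # Pick the decompressor by bit-width; the INT4 variants also track the packed
+            # and unpacked shapes needed to restore the original weight layout.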
+            if quantizer.num_bits == 8:
+                decompressor = INT8AsymmetricWeightsDecompressor(
+                    scale=scale, zero_point=zero_point, result_dtype=original_dtype
+                )
+            else:
+                decompressor = INT4AsymmetricWeightsDecompressor(
+                    scale=scale,
+                    zero_point=zero_point,
+                    compressed_weight_shape=q_weight.shape,
+                    result_shape=original_shape,
+                    result_dtype=original_dtype,
+                )
+
+        elif isinstance(quantizer, SymmetricQuantizer):
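+            # Symmetric case: no zero point; the DQ scale is the FQ scale normalized by |level_low|.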
+            integer_dtype = torch.int8
+
+            scale = quantizer.scale / abs(quantizer.level_low)
+            scale = torch.where(torch.abs(scale) < original_eps, original_eps, scale)
+            scale = scale.to(original_dtype)
+
+            q_weight = qdq_weight / scale
+            q_weight = torch.round(q_weight)
+            q_weight = torch.clip(q_weight, quantizer.level_low, quantizer.level_high)
+            q_weight = q_weight.to(integer_dtype)
+
+            if quantizer.num_bits == 8:
+                decompressor = INT8SymmetricWeightsDecompressor(scale=scale, result_dtype=original_dtype)
+            else:
+                decompressor = INT4SymmetricWeightsDecompressor(
+                    scale=scale,
+                    compressed_weight_shape=q_weight.shape,
+                    result_shape=original_shape,
+                    result_dtype=original_dtype,
+                )
+
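+        # Pack the integer weight into the decompressor's storage layout
+        # (e.g. two 4-bit values per byte for the INT4 decompressors).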
+        packed_tensor = decompressor.pack_weight(q_weight)
+
+        # Replace the original floating-point weight with the packed integer tensor
+        compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
+        setattr(module, weight_attr_name, compressed_parameter)
+
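+        # If the weight is shared, point every consumer module at the same compressed parameter.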
+        consumer_nodes = graph.get_next_nodes(weight_node)
+        if len(consumer_nodes) > 1:
+            for consumer_node in consumer_nodes:
+                consumer_module = model.nncf.get_module_by_scope(Scope.from_str(consumer_node.layer_name))
+                for name, param in consumer_module.named_parameters(recurse=False, remove_duplicate=False):
+                    if id(param) == id(original_weight):
+                        setattr(consumer_module, name, compressed_parameter)
+
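+        # Register the decompressor as a post-hook after the weight node; the experimental
+        # torch tracing path uses a different insertion command and node-name convention.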
+        if is_experimental_torch_tracing_enabled():
+            transformation_layout.register(
+                PT2InsertionCommand(
+                    [
+                        PTTargetPoint(
+                            TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
+                        )
+                    ],
+                    decompressor,
+                )
+            )
+        else:
+            decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
+            transformation_layout.register(
+                PTSharedFnInsertionCommand(
+                    [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
+                    decompressor,
+                    decompressor_name,
+                )
+            )
+
+    return PTModelTransformer(model).transform(transformation_layout)