Skip to content

Commit a8ba008

Browse files
[PT2] weight_compression (openvinotoolkit#3293)
### Changes

Implement weight compression algorithms for experimental PT tracing.

### Related tickets

152996
1 parent 73590b0 commit a8ba008

File tree

8 files changed

+569
-40
lines changed

8 files changed

+569
-40
lines changed

nncf/quantization/algorithms/weight_compression/torch_backend.py

+62-28
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, Dict, Iterable, List, Optional, Tuple
12+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
1313

1414
import torch
1515

@@ -23,6 +23,7 @@
2323
from nncf.common.graph.transformations.commands import TargetType
2424
from nncf.common.graph.transformations.layout import TransformationLayout
2525
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
26+
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
2627
from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
2728
from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
2829
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
@@ -35,6 +36,9 @@
3536
from nncf.experimental.common.tensor_statistics.statistics import MeanMagnitudeTensorStatistic
3637
from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic
3738
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
39+
from nncf.experimental.torch2.commands import PT2InsertionCommand
40+
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
41+
from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
3842
from nncf.parameters import CompressWeightsMode
3943
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
4044
from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
@@ -46,7 +50,6 @@
4650
from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
4751
from nncf.tensor import Tensor
4852
from nncf.tensor.definitions import TensorDataType
49-
from nncf.torch.dynamic_graph.scope import Scope
5053
from nncf.torch.graph import operator_metatypes as om
5154
from nncf.torch.graph.operator_metatypes import PTMulMetatype
5255
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
@@ -186,8 +189,14 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
186189
return activation_ports[0]
187190

188191
def get_weight(
189-
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
192+
self,
193+
node_with_weight: NNCFNode,
194+
weight_port_id: int,
195+
model: Union[GraphModelWrapper, torch.nn.Module],
196+
graph: NNCFGraph,
190197
) -> Tensor:
198+
if isinstance(model, GraphModelWrapper):
199+
model = model.model
191200
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
192201
weight_name = weight_node.layer_attributes.name
193202
weight = get_const_data(weight_node, model)
@@ -197,7 +206,11 @@ def get_weight(
197206
return Tensor(weight)
198207

199208
def get_weight_dtype(
200-
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
209+
self,
210+
node_with_weight: NNCFNode,
211+
weight_port_id: int,
212+
model: Union[GraphModelWrapper, torch.nn.Module],
213+
graph: NNCFGraph,
201214
) -> TensorDataType:
202215
return self.get_weight(node_with_weight, weight_port_id, model, graph).dtype
203216

@@ -209,7 +222,14 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
209222
def set_weight(
210223
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
211224
):
212-
update_parameter(node_with_weight.node_name, "weight", weight.data, model)
225+
if is_experimental_torch_tracing_enabled():
226+
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
227+
module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
228+
module = get_module_by_name(module_name, model.model)
229+
weight_param = getattr(module, weight_attr_name)
230+
weight_param.data = weight.data
231+
else:
232+
update_parameter(node_with_weight.node_name, "weight", weight.data, model)
213233

214234
def insert_adapters(
215235
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
@@ -229,13 +249,19 @@ def filter_func(point: StatisticPoint) -> bool:
229249

230250
def transform_model(
231251
self,
232-
model: NNCFNetwork,
252+
model: Union[GraphModelWrapper, torch.nn.Module],
233253
graph: NNCFGraph,
234254
weight_compression_parameters: Iterable[WeightCompressionParameters],
235255
precomputed_scales: Dict[str, Tensor] = None,
236256
precomputed_zero_points: Dict[str, Tensor] = None,
237257
lora_correction_algo: LoraCorrectionAlgorithm = None,
238258
) -> NNCFNetwork:
259+
if isinstance(model, GraphModelWrapper):
260+
model_transformer = PT2ModelTransformer(model)
261+
model = model.model
262+
else:
263+
model_transformer = PTModelTransformer(model)
264+
239265
transformation_layout = TransformationLayout()
240266

241267
for wc_params in weight_compression_parameters:
@@ -291,38 +317,43 @@ def transform_model(
291317

292318
# sets compressed tensor
293319
# TODO:(AlexanderDokuchaev): update set_const_data
294-
compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False)
295320
module_name, weight_attr_name = split_const_name(weight_name)
296321
module = get_module_by_name(module_name, model)
297322
weight = getattr(module, weight_attr_name)
323+
298324
if not isinstance(weight, torch.nn.Parameter):
299325
msg = f"Weight is not a torch.nn.Parameter in the model by name {weight_name}."
300326
raise nncf.InternalError(msg)
301327

302-
setattr(module, weight_attr_name, compressed_parameter)
303-
304-
consumer_nodes = graph.get_next_nodes(weight_node)
305-
if len(consumer_nodes) > 1:
306-
for c_node in consumer_nodes:
307-
c_module = model.nncf.get_module_by_scope(Scope.from_str(c_node.layer_name))
308-
for name, param in c_module.named_parameters(recurse=False, remove_duplicate=False):
309-
if id(param) == id(weight):
310-
setattr(c_module, name, compressed_parameter)
311-
312-
# registry weight decompression module in the model
313-
decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
314-
315-
# inserts the weight decompressor into the model as the post hook on the model weight
316-
transformation_layout.register(
317-
PTSharedFnInsertionCommand(
318-
[PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
319-
decompressor,
320-
decompressor_name,
328+
weight.requires_grad = False
329+
weight.data = packed_tensor
330+
331+
if is_experimental_torch_tracing_enabled():
332+
transformation_layout.register(
333+
PT2InsertionCommand(
334+
[
335+
PTTargetPoint(
336+
TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name.replace(".", ":")
337+
)
338+
],
339+
decompressor,
340+
)
341+
)
342+
else:
343+
# register the weight decompression module in the model
344+
decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}"
345+
346+
# inserts the weight decompressor into the model as the post hook on the model weight
347+
transformation_layout.register(
348+
PTSharedFnInsertionCommand(
349+
[PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=weight_node.node_name)],
350+
decompressor,
351+
decompressor_name,
352+
)
321353
)
322-
)
323354

324355
# apply transformations
325-
transformed_model = PTModelTransformer(model).transform(transformation_layout)
356+
transformed_model = model_transformer.transform(transformation_layout)
326357

327358
return transformed_model
328359

@@ -356,6 +387,9 @@ def scale_insertion_command(
356387

357388
sq_multiply = SQMultiply(scale.shape)
358389
sq_multiply.scale = scale
390+
391+
if is_experimental_torch_tracing_enabled():
392+
return PT2InsertionCommand(target_points, sq_multiply)
359393
scale_node_name = f"{source_node.node_name}/awq_mul"
360394
return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
361395

nncf/quantization/quantize_model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def compress_weights(
511511

512512
if backend == BackendType.TORCH:
513513
from nncf.torch.model_creation import is_wrapped_model
514-
from nncf.torch.model_creation import wrap_model
514+
from nncf.torch.nncf_network import NNCFNetwork
515515
from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl
516516

517517
if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
@@ -529,7 +529,7 @@ def compress_weights(
529529
raise nncf.ParameterNotSupportedError(msg)
530530

531531
if is_wrapped_model(model):
532-
if not model.nncf.trace_parameters:
532+
if isinstance(model, NNCFNetwork) and not model.nncf.trace_parameters:
533533
msg = (
534534
"Tracing capabilities with tracing parameters are required in the PyTorch model "
535535
"for nncf.compress_weights(). Please wrap the model using "
@@ -541,6 +541,8 @@ def compress_weights(
541541
msg = "Please provide a dataset of at least one element for PyTorch model tracing."
542542
raise nncf.ValidationError(msg)
543543
else:
544+
from nncf.torch.model_creation import wrap_model
545+
544546
example_input = next(iter(dataset.get_inference_data()))
545547
model = wrap_model(model, example_input=example_input, trace_parameters=True)
546548
if mode in (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM):

nncf/torch/model_creation.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from nncf.config.extractors import extract_algorithm_names
2828
from nncf.config.extractors import has_input_info_field
2929
from nncf.config.telemetry_extractors import CompressionStartedFromConfig
30+
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
3031
from nncf.telemetry import tracked_function
3132
from nncf.telemetry.events import NNCF_PT_CATEGORY
3233
from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
@@ -337,7 +338,7 @@ def wrap_model(
337338
model: torch.nn.Module,
338339
example_input: Any,
339340
trace_parameters: bool = False,
340-
) -> NNCFNetwork:
341+
) -> Any:
341342
"""
342343
Wraps a PyTorch model to the NNCFNetwork class.
343344
@@ -348,8 +349,18 @@ def wrap_model(
348349
as an example input of a set of non keyword arguments, and a dict as an example input of a set
349350
of keywords arguments.
350351
:param trace_parameters: Whether to trace model parameters. Default is False.
351-
:return: A model wrapped by NNCFNetwork.
352+
:return: A model wrapped by NNCFNetwork or GraphModelWrapper if experimental PyTorch model tracing is enabled.
352353
"""
354+
if is_experimental_torch_tracing_enabled():
355+
if not trace_parameters:
356+
msg = "The 'trace_parameters=False' option is not supported in the experimental tracing mode."
357+
raise nncf.InternalError(msg)
358+
from nncf.experimental.torch2.function_hook import wrap_model
359+
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
360+
361+
wrapped_model = GraphModelWrapper(wrap_model(model), example_input=example_input)
362+
return wrapped_model
363+
353364
if not isinstance(model, torch.nn.Module):
354365
msg = (
355366
f"The provided model type {type(model)} is incompatible. "
@@ -368,14 +379,16 @@ def wrap_model(
368379
return nncf_network
369380

370381

371-
def is_wrapped_model(model: torch.nn.Module) -> bool:
382+
def is_wrapped_model(model: Any) -> bool:
372383
"""
373-
Check that the model was wrapped by NNCFNetwork.
384+
Check that the model was wrapped by NNCFNetwork or GraphModelWrapper.
374385
375386
:param model: A model.
376387
:return: True if the model is wrapped, False otherwise.
377388
"""
378-
return isinstance(model, NNCFNetwork)
389+
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
390+
391+
return isinstance(model, (NNCFNetwork, GraphModelWrapper))
379392

380393

381394
@tracked_function(

nncf/torch/quantization/quantize_model.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
# limitations under the License.
1111

1212
from copy import deepcopy
13-
from typing import Optional
13+
from typing import Optional, Union
1414

1515
import torch
1616

1717
import nncf
1818
from nncf.common.factory import NNCFGraphFactory
1919
from nncf.common.quantization.structs import QuantizationPreset
2020
from nncf.data import Dataset
21+
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
2122
from nncf.parameters import BackupMode
2223
from nncf.parameters import CompressWeightsMode
2324
from nncf.parameters import ModelType
@@ -85,7 +86,7 @@ def quantize_impl(
8586

8687

8788
def compress_weights_impl(
88-
model: torch.nn.Module,
89+
model: Union[GraphModelWrapper, torch.nn.Module],
8990
dataset: Dataset,
9091
mode: CompressWeightsMode,
9192
ratio: float,
@@ -120,4 +121,8 @@ def compress_weights_impl(
120121
advanced_parameters,
121122
)
122123
graph = NNCFGraphFactory.create(model)
123-
return compression_algorithm.apply(model, graph, dataset=dataset)
124+
125+
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
126+
if isinstance(compressed_model, GraphModelWrapper):
127+
compressed_model = compressed_model.model
128+
return compressed_model

tests/post_training/pipelines/lm_weight_compression.py

+4
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,11 @@ def _dump_model_fp32(self) -> None:
285285
self.model_hf.save_pretrained(self.fp32_model_dir)
286286
self.model_hf._save_config(self.fp32_model_dir)
287287
elif self.backend == BackendType.TORCH:
288+
_need_clean_dict = "forward" not in self.model_hf.__dict__
288289
export_from_model(self.model_hf, self.fp32_model_dir, stateful=False, compression_option="fp32")
290+
if _need_clean_dict and "forward" in self.model_hf.__dict__:
291+
# Workaround for experimental tracing: clean up the overwritten forward (same as in class method)
292+
del self.model_hf.__dict__["forward"]
289293

290294
def _compress(self):
291295
"""

tests/torch/ptq/test_weights_compression.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -517,11 +517,11 @@ def get_scale_estimation_ref():
517517

518518
@staticmethod
519519
def get_orig_weight(model: torch.nn.Module) -> Tensor:
520-
return Tensor(model.linear.weight)
520+
return Tensor(model.linear.weight.data.detach())
521521

522522
@staticmethod
523523
def get_decompressed_weight(compressed_model: torch.nn.Module, input: torch.Tensor) -> Tensor:
524-
weight = compressed_model.linear.weight
524+
weight = compressed_model.linear.weight.data.detach()
525525
unpacked_w = compressed_model.nncf.external_op.weights_decompressor_linear_weight(weight)
526526
return Tensor(unpacked_w)
527527

0 commit comments

Comments
 (0)