Commit 5951153

[PT2] SmoothQuant (#3276)

### Changes

Implemented SmoothQuant support for the experimental tracing backend.
Added PT2ConstUpdateCommand to update constant data in the model.

### Related tickets

152996

1 parent 033be9b commit 5951153
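The core SmoothQuant identity these commands serve is that a per-channel activation scale can be folded into the weights without changing the layer output; the rescaled weight is what PT2ConstUpdateCommand writes back into the model. A minimal illustrative sketch (plain PyTorch, not NNCF code; the scale below is a simplified activation-only variant of what the algorithm actually computes from activation and weight statistics):

    import torch

    x = torch.randn(8, 4)           # activations, 4 input channels
    w = torch.randn(3, 4)           # linear weight [out_features, in_features]
    s = x.abs().amax(dim=0) ** 0.5  # simplified per-channel smoothing scale

    y_ref = x @ w.t()
    y_sq = (x / s) @ (w * s).t()    # smoothed activations, rescaled weights
    assert torch.allclose(y_ref, y_sq, atol=1e-5)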

File tree

11 files changed: +202 -18 lines

nncf/experimental/torch2/commands.py (+17)

@@ -11,8 +11,10 @@
 
 from typing import List, Optional
 
+import torch
 from torch import nn
 
+from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import Command
 from nncf.common.graph.transformations.commands import TransformationType
 from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
@@ -41,3 +43,18 @@ def __init__(
         self.target_points = target_points
         self.hook_module = hook_module
         self.handle_storage = handle_storage
+
+
+class PT2ConstUpdateCommand(Command):
+    """
+    Updates a constant (weight) value in the model with the given tensor.
+    """
+
+    def __init__(self, node: NNCFNode, value: torch.Tensor):
+        """
+        :param node: The node of the constant data in the model.
+        :param value: The new value of the constant.
+        """
+        super().__init__(TransformationType.CHANGE)
+        self.node = node
+        self.value = value
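A hedged usage sketch of the new command: the constant's NNCFNode and the replacement tensor are passed directly, and the command is queued in a TransformationLayout for the torch2 model transformer. The function name below is illustrative and not part of this commit:

    import torch

    from nncf.common.graph.graph import NNCFNode
    from nncf.common.graph.transformations.layout import TransformationLayout
    from nncf.experimental.torch2.commands import PT2ConstUpdateCommand

    def make_weight_update_layout(weight_node: NNCFNode, new_weight: torch.Tensor) -> TransformationLayout:
        # Queue a single constant-update command; the torch2 model transformer
        # applies it via set_const_data() when transform() is called.
        layout = TransformationLayout()
        layout.register(PT2ConstUpdateCommand(weight_node, new_weight))
        return layout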

nncf/experimental/torch2/model_transformer.py (+21)

@@ -18,13 +18,15 @@
 from nncf.common.graph.transformations.commands import Command
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.experimental.torch2.commands import PT2ConstUpdateCommand
 from nncf.experimental.torch2.commands import PT2InsertionCommand
 from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
 from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook
 from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.model_graph_manager import set_const_data
 from nncf.torch.model_graph_manager import update_fused_bias
 
 TRANSFORMATION_PAIRS = Tuple[Tuple[Type[Any], Callable[[GraphModelWrapper, List[Any]], GraphModelWrapper]], ...]
@@ -41,6 +43,7 @@ def __init__(self, model: GraphModelWrapper):
         self._command_transformation_ordered_pairs: TRANSFORMATION_PAIRS = (
             (PT2InsertionCommand, self._apply_insertion_transformations),
             (PTBiasCorrectionCommand, self._apply_bias_correction_transformations),
+            (PT2ConstUpdateCommand, self._apply_const_update_transformations),
         )
 
     def transform(self, transformation_layout: TransformationLayout) -> GraphModelWrapper:
@@ -114,6 +117,24 @@ def _apply_bias_correction_transformations(
         )
         return wrapped_model
 
+    @staticmethod
+    def _apply_const_update_transformations(
+        wrapped_model: GraphModelWrapper, transformations: List[PT2ConstUpdateCommand]
+    ) -> GraphModelWrapper:
+        """
+        Applies constant data update transformations to the model.
+
+        :param wrapped_model: Model to apply transformations to.
+        :param transformations: List of the constant data update transformations.
+        :return: Model with updated constant data.
+        """
+        for transformation in transformations:
+            node = transformation.node
+            value = transformation.value
+            set_const_data(value, node, wrapped_model.model)
+
+        return wrapped_model
+
 
 def insert_hook(model: nn.Module, hook: nn.Module, target_point: PTTargetPoint) -> RemovableHookHandle:
     """

nncf/quantization/algorithms/min_max/torch_backend.py (+7 -1)

@@ -26,6 +26,7 @@
 from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
 from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
+from nncf.experimental.torch2.commands import PT2InsertionCommand
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
@@ -271,6 +272,9 @@ def create_quantizer_insertion_command(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
+        if is_experimental_torch_tracing_enabled():
+            return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)
+
         return create_quantizer_insertion_command(target_point, quantizer)
 
     @staticmethod
@@ -287,6 +291,8 @@ def create_unified_scales_quantizers_insertion_commands(
         quantizer = PTMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_points[0].target_type
         )
+        if is_experimental_torch_tracing_enabled():
+            return [PT2InsertionCommand(target_points=target_points, hook_module=quantizer)]
         return [create_shared_quantizer_insertion_command(target_points, quantizer)]
 
     @staticmethod
@@ -312,7 +318,7 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
             # Batchnorm
             om.PTBatchNormMetatype,
             om.PTModuleBatchNormMetatype,
-            # Сomparison operations
+            # Comparison operations
             om.PTGreaterEqualMetatype,
             om.PTGreaterMetatype,
             om.PTLessEqualMetatype,
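Under the experimental flag, the quantizer module itself becomes the hook attached by PT2InsertionCommand. A hedged construction example (the helper name and node name are illustrative, and PTTargetPoint's signature is assumed from its current NNCF usage, not from this commit):

    from nncf.common.graph.transformations.commands import TargetType
    from nncf.experimental.torch2.commands import PT2InsertionCommand
    from nncf.torch.graph.transformations.commands import PTTargetPoint

    def quantizer_insertion(quantizer, node_name: str) -> PT2InsertionCommand:
        # Attach the quantizer module as a post-hook on the given node's output.
        tp = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, node_name)
        return PT2InsertionCommand(target_points=[tp], hook_module=quantizer)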

nncf/quantization/algorithms/smooth_quant/algorithm.py (+3 -1)

@@ -160,7 +160,9 @@ def apply(
                 weight_value = self._backend_entity.get_weight_value(node_to_smooth, model, graph)
                 weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value)
                 scaled_weight = weight_value * weights_scale
-                weight_update_command = self._backend_entity.weight_update_command(node_to_smooth, scaled_weight.data)
+                weight_update_command = self._backend_entity.weight_update_command(
+                    node_to_smooth, graph, scaled_weight.data
+                )
                 transformation_layout.register(weight_update_command)
 
             activations_by_output_id = {e.output_port_id: e for e in graph.get_output_edges(source_node)}

nncf/quantization/algorithms/smooth_quant/backend.py (+2 -2)

@@ -127,14 +127,14 @@ def get_weight_value(node_with_weight: NNCFNode, model: TModel, port_id: int, nn
     @staticmethod
     @abstractmethod
     def weight_update_command(
-        node_with_weight: NNCFNode, weight_value: TTensor, weight_port_id: int
+        node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: TTensor
     ) -> TransformationCommand:
         """
         Returns command to update weights.
 
         :param node_with_weight: NNCFNode instance.
+        :param nncf_graph: NNCFGraph instance.
         :param weight_value: New weight value.
-        :param weight_port_id: Weight port id.
         :return: TransformationCommand instance.
         """
 

nncf/quantization/algorithms/smooth_quant/openvino_backend.py (+3 -1)

@@ -100,7 +100,9 @@ def get_weight_tensor_port_id(node: NNCFNode) -> int:
         return const_ids[0]
 
     @staticmethod
-    def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand:
+    def weight_update_command(
+        node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: np.ndarray
+    ) -> OVWeightUpdateCommand:
         weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight)
         return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id)

nncf/quantization/algorithms/smooth_quant/torch_backend.py (+20 -2)

@@ -11,7 +11,6 @@
 
 from typing import Any, Callable, Dict, List, Tuple
 
-import numpy as np
 import torch
 
 import nncf.torch.graph.operator_metatypes as om
@@ -21,9 +20,13 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
 from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
+from nncf.experimental.torch2.commands import PT2ConstUpdateCommand
+from nncf.experimental.torch2.commands import PT2InsertionCommand
+from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
 from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
 from nncf.tensor import Tensor
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
@@ -119,6 +122,9 @@ def get_abs_max_channel_collector(
 
     @staticmethod
     def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph: NNCFGraph) -> Tensor:
+        if isinstance(model, GraphModelWrapper):
+            model = model.model
+
         weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
         if weight_node is None:
             msg = f"{node_with_weight} node has no weight node."
@@ -127,7 +133,12 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
         return Tensor(weight_data)
 
     @staticmethod
-    def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> PTWeightUpdateCommand:
+    def weight_update_command(
+        node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: torch.Tensor
+    ) -> PTWeightUpdateCommand:
+        if is_experimental_torch_tracing_enabled():
+            weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
+            return PT2ConstUpdateCommand(weight_node, weight_value)
         return create_command_to_update_weight(node_with_weight, weight_value)
 
     @staticmethod
@@ -145,6 +156,9 @@ def scale_insertion_command(
 
         sq_multiply = SQMultiply(scale_value.shape)
         sq_multiply.scale = scale_value
+
+        if is_experimental_torch_tracing_enabled():
+            return PT2InsertionCommand(target_points=target_points, hook_module=sq_multiply)
         return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
 
     @staticmethod
@@ -161,6 +175,10 @@ def get_weight_channel_axis(node: NNCFNode) -> int:
 
     @staticmethod
    def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
+        if is_experimental_torch_tracing_enabled():
+            weight_node = get_const_node(node, node.metatype.weight_port_ids[0], nncf_graph)
+            output_edges = nncf_graph.get_next_nodes(weight_node)
+            return len(output_edges) > 1
         return node.is_shared()
 
     @staticmethod

nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py (+3 -1)

@@ -104,7 +104,9 @@ def get_weight_value(node_with_weight: NNCFNode, model: torch.fx.GraphModule, nn
         return Tensor(weight_data.data)
 
     @staticmethod
-    def weight_update_command(node_with_weight: NNCFNode, weight_value: torch.Tensor) -> FXApplyTransformationCommand:
+    def weight_update_command(
+        node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: torch.Tensor
+    ) -> FXApplyTransformationCommand:
         # TODO(dlyakhov): Use input port id depending on the node metatype/attributes.
         return FXApplyTransformationCommand(
             constant_update_transformation_builder(node_with_weight, weight_value.data, input_port_id=1)

nncf/torch/graph/transformations/command_creation.py (-7)

@@ -18,8 +18,6 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.common.quantization.structs import NonWeightQuantizerId
-from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
-from nncf.experimental.torch2.commands import PT2InsertionCommand
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
@@ -56,9 +54,6 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
 def create_quantizer_insertion_command(
     target_point: PTTargetPoint, quantizer: BaseQuantizer
 ) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
-    if is_experimental_torch_tracing_enabled():
-        return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)
-
     quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
     storage_key = str(quantizer_id)
     return PTSharedFnInsertionCommand(
@@ -73,8 +68,6 @@ def create_quantizer_insertion_command(
 def create_shared_quantizer_insertion_command(
     target_points: List[PTTargetPoint], quantizer: BaseQuantizer
 ) -> PTSharedFnInsertionCommand:
-    if is_experimental_torch_tracing_enabled():
-        return PT2InsertionCommand(target_points=target_points, hook_module=quantizer)
     quantizers_ids = []
     for target_point in target_points:
         quantizers_ids.append(NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id))

nncf/torch/layer_utils.py (+3 -3)

@@ -10,7 +10,6 @@
 # limitations under the License.
 
 from abc import ABC
-from abc import abstractclassmethod
 from abc import abstractmethod
 from typing import Any, Dict
 
@@ -29,7 +28,7 @@ class StatefullModuleInterface(ABC):
     Interface that should be implemented for every registered compression module to make it possible
     to save an compression modules state and create an compression module from the saved state.
     Config of the module should be json serializable, no python objects except
-    standart (str, list and etc.) should be present in a compression module config.
+    standard (str, list and etc.) should be present in a compression module config.
     Values for attributes with type torch.nn.Parameter
     is recovered from the model `state_dict`, so there is no need to keep them in the module config.
     Modules should avoid implementation of `__call__` method and use `forward` method instead,
@@ -44,7 +43,8 @@ def get_config(self) -> Dict[str, Any]:
         Returns the compression module config.
         """
 
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def from_config(cls, state: Dict[str, Any]) -> object:
         """
         Creates a compression module instance from the given config.
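The layer_utils.py change replaces the long-deprecated abc.abstractclassmethod with the stacked decorators Python recommends. The pattern in isolation (illustrative class name, not NNCF code):

    from abc import ABC, abstractmethod
    from typing import Any, Dict

    class ConfigurableModule(ABC):
        @classmethod
        @abstractmethod
        def from_config(cls, state: Dict[str, Any]) -> "ConfigurableModule":
            """Create an instance from a serialized config."""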
