Commit 5ad9bc4

[PT2] MinMax (#3166)
### Changes

Introduce the MinMax algorithm for the experimental TORCH2 backend. Add `handle_torch_function` to the quantization function so that it can be traced via `__torch_function__`.

### Related tickets

152996

### Tests

[test install](https://github.com/openvinotoolkit/nncf/actions/runs/13014595451)
1 parent 6565033 commit 5ad9bc4
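
The `handle_torch_function` part of the change follows the standard `torch.overrides` protocol. Below is a minimal sketch of that pattern; the function name `fake_quantize` and its parameters are placeholders for illustration, not NNCF's actual quantization operator.

```python
# Sketch of the handle_torch_function pattern; `fake_quantize` is a placeholder name.
import torch
from torch.overrides import handle_torch_function, has_torch_function


def fake_quantize(x: torch.Tensor, scale: float = 1.0, zero_point: int = 0) -> torch.Tensor:
    # Defer to any __torch_function__ override (e.g. a tracing tensor) so the call
    # is recorded as a single `fake_quantize` node instead of its inner primitive ops.
    if has_torch_function((x,)):
        return handle_torch_function(fake_quantize, (x,), x, scale, zero_point)
    return torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)
```

With this guard in place, a `__torch_function__`-based tracer sees the whole quantization call as one operation, which is what the new torch2 graph building relies on.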

46 files changed, +14748 -125 lines (only a subset of the changed files is shown below)

nncf/common/factory.py (+30, -8)

@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TypeVar, cast
+from typing import Any, TypeVar, cast
 
 import nncf
 from nncf.common.engine import Engine
@@ -20,6 +20,7 @@
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.data.dataset import Dataset
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
 
 TModel = TypeVar("TModel")
 
@@ -53,17 +54,22 @@ def create(model: TModel) -> NNCFGraph:
 
             return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
         if model_backend == BackendType.TORCH:
+            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
             from nncf.torch.nncf_network import NNCFNetwork
 
-            return cast(NNCFNetwork, model).nncf.get_graph()
+            if isinstance(model, GraphModelWrapper):
+                return model.build_graph()
+            if isinstance(model, NNCFNetwork):
+                return model.nncf.get_graph()
+            raise nncf.InternalError(f"Unexpected type of model {type(model)} for TORCH backend")
         raise nncf.UnsupportedBackendError(
-            "Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
+            f"Cannot create backend-specific graph because {model_backend.value} is not supported!"
         )
 
 
 class ModelTransformerFactory:
     @staticmethod
-    def create(model: TModel, inplace: bool = False) -> ModelTransformer:
+    def create(model: TModel, inplace: bool = False) -> ModelTransformer[Any]:
         """
         Factory method to create backend-specific ModelTransformer instance based on the input model.
 
@@ -84,11 +90,18 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
             from nncf.openvino.graph.model_transformer import OVModelTransformer
 
             return OVModelTransformer(cast(Model, model), inplace=inplace)
-        if model_backend == BackendType.TORCH:
+        if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
+            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+            from nncf.experimental.torch2.model_transformer import PT2ModelTransformer
+
+            return PT2ModelTransformer(cast(GraphModelWrapper, model))
+
+        if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
             from nncf.torch.model_transformer import PTModelTransformer
             from nncf.torch.nncf_network import NNCFNetwork
 
             return PTModelTransformer(cast(NNCFNetwork, model))
+
        if model_backend == BackendType.TORCH_FX:
             from torch.fx import GraphModule
 
@@ -125,11 +138,16 @@ def create(model: TModel) -> Engine:
         if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
             from torch.nn import Module
 
+            from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
             from nncf.torch.engine import PTEngine
 
-            return PTEngine(cast(Module, model))
+            if isinstance(model, GraphModelWrapper):
+                pt_model = model.model
+            else:
+                pt_model = cast(Module, model)
+            return PTEngine(pt_model)
         raise nncf.UnsupportedBackendError(
-            "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
+            f"Cannot create backend-specific engine because {model_backend.value} is not supported!"
         )
 
 
@@ -176,10 +194,14 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
             from nncf.openvino.statistics.aggregator import OVStatisticsAggregator
 
             return OVStatisticsAggregator(dataset)
-        if model_backend == BackendType.TORCH:
+        if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
             from nncf.torch.statistics.aggregator import PTStatisticsAggregator
 
             return PTStatisticsAggregator(dataset)
+        if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
+            from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator
+
+            return PT2StatisticsAggregator(dataset)
         if model_backend == BackendType.TORCH_FX:
             from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator
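
With these branches in place, `ModelTransformerFactory.create` returns `PT2ModelTransformer` only when the experimental flag is set and the model is a `GraphModelWrapper`; otherwise the legacy `PTModelTransformer` path is used. A rough usage sketch follows; the toy model and example input are placeholders, and any additional function-hook wrapping the torch2 backend may require is not shown in this diff.

```python
# Sketch only: assumes the experimental flag is set before the factory is used.
import os

os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"

import torch

from nncf.common.factory import ModelTransformerFactory
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

toy_model = torch.nn.Linear(4, 4)  # placeholder model
wrapped = GraphModelWrapper(toy_model, example_input=torch.ones(1, 4))

transformer = ModelTransformerFactory.create(wrapped)  # experimental path -> PT2ModelTransformer
```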

nncf/common/graph/model_transformer.py (+3, -3)

@@ -9,14 +9,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TypeVar
+from typing import Generic, TypeVar
 
 from nncf.common.graph.transformations.layout import TransformationLayout
 
 TModel = TypeVar("TModel")
 
 
-class ModelTransformer:
+class ModelTransformer(Generic[TModel]):
     """
     Applies transformations to the model.
     """
@@ -29,7 +29,7 @@ def __init__(self, model: TModel):
         """
         self._model = model
 
-    def transform(self, transformation_layout: TransformationLayout) -> TModel:  # type:ignore
+    def transform(self, transformation_layout: TransformationLayout) -> TModel:
         """
         Applies transformations to the model.
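
Making `ModelTransformer` generic lets each backend transformer declare its concrete model type, so `transform` no longer needs the `# type:ignore` escape hatch. A minimal sketch of a subclass under the new signature; `ToyModel` and `ToyModelTransformer` are illustrative names only, not NNCF classes.

```python
# Illustrative subclass of the now-generic base class; not an NNCF backend.
from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.layout import TransformationLayout


class ToyModel:
    """Placeholder model type."""


class ToyModelTransformer(ModelTransformer[ToyModel]):
    def transform(self, transformation_layout: TransformationLayout) -> ToyModel:
        # A real backend would apply each registered command to self._model here.
        return self._model
```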

nncf/common/graph/transformations/layout.py (+4, -4)

@@ -11,7 +11,7 @@
 
 from typing import List
 
-from nncf.common.graph.transformations.commands import TransformationCommand
+from nncf.common.graph.transformations.commands import Command
 
 
 class TransformationLayout:
@@ -27,13 +27,13 @@ def __init__(self) -> None:
         """
         Initialize Transformation Layout.
         """
-        self._transformations: List[TransformationCommand] = []
+        self._transformations: List[Command] = []
 
     @property
-    def transformations(self) -> List[TransformationCommand]:
+    def transformations(self) -> List[Command]:
        return self._transformations
 
-    def register(self, transformation: TransformationCommand) -> None:
+    def register(self, transformation: Command) -> None:
         """
         Registers the transformation command in the transformation layout.

nncf/common/quantization/quantizer_removal.py (+1, -1)

@@ -178,6 +178,6 @@ def revert_operations_to_floating_point_precision(
     )
 
     model_transformer = ModelTransformerFactory.create(quantized_model)
-    transformed_model = model_transformer.transform(transformation_layout)  # type: ignore[var-annotated]
+    transformed_model = model_transformer.transform(transformation_layout)
 
     return cast(TModel, transformed_model)

nncf/common/scopes.py (+3, -3)

@@ -10,7 +10,7 @@
 # limitations under the License.
 
 import re
-from typing import Iterable, List, Optional, Sequence, Union
+from typing import Iterable, List, Optional, Union
 
 import nncf
 from nncf.common.graph import NNCFGraph
@@ -52,8 +52,8 @@ def matches_any(tested_str: str, strs_to_match_to: Union[Iterable[str], str, Non
 
 def should_consider_scope(
     serializable_id: Union[QuantizerId, NNCFNodeName],
-    ignored_scopes: Optional[Sequence[str]],
-    target_scopes: Optional[Sequence[str]] = None,
+    ignored_scopes: Optional[Iterable[str]],
+    target_scopes: Optional[Iterable[str]] = None,
 ) -> bool:
     """
     Used when an entity arising during compression has to be compared to an allowlist or a denylist of strings.

nncf/common/tensor_statistics/aggregator.py (+3, -3)

@@ -22,13 +22,13 @@
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
-from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.tensor_statistics.statistics_serializer import dump_statistics
 from nncf.common.tensor_statistics.statistics_serializer import load_statistics
 from nncf.common.utils.backend import BackendType
 from nncf.data.dataset import Dataset
 from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
+from nncf.tensor import Tensor
 
 TensorType = TypeVar("TensorType")
 TModel = TypeVar("TModel")
@@ -165,7 +165,7 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer)
                            self.stat_subset_size = max(self.stat_subset_size, tensor_collector.num_samples)
 
     @abstractmethod
-    def _register_statistics(self, outputs: Dict[str, NNCFTensor], statistic_points: StatisticPointsContainer) -> None:
+    def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
         """
         Process prepared raw model outputs and statistic points for the further usage.
 
@@ -203,7 +203,7 @@ def _get_merged_statistic_points(
 
     @staticmethod
     @abstractmethod
-    def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
+    def _process_outputs(outputs: Any) -> Dict[str, Tensor]:
         """
         Post-process model outputs for the further statistics collection.
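
The aggregator interface now exchanges `nncf.tensor.Tensor` wrappers instead of the legacy `NNCFTensor`. A rough sketch of what a backend `_process_outputs` looks like under the new signature; the shape of `outputs` (a name-to-`torch.Tensor` mapping) is an assumption for illustration.

```python
# Sketch of a backend _process_outputs under the new Dict[str, Tensor] contract.
from typing import Any, Dict

from nncf.tensor import Tensor


def _process_outputs(outputs: Any) -> Dict[str, Tensor]:
    # Wrap each raw framework tensor in the unified nncf Tensor wrapper.
    return {name: Tensor(value) for name, value in outputs.items()}
```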

nncf/common/utils/backend.py (+6)

@@ -13,6 +13,7 @@
 from typing import Any, Callable, TypeVar, cast
 
 import nncf
+from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
 
 try:
     import openvino  # type: ignore # noqa: F401
@@ -53,6 +54,11 @@ def is_torch_model(model: Any) -> bool:
     import torch
     import torch.fx
 
+    from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
+
+    if is_experimental_torch_tracing_enabled():
+        return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)
+
     return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)
 
 
nncf/experimental/common/check_feature.py (new file, +21; the path is taken from the import statements added elsewhere in this commit)

@@ -0,0 +1,21 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+
+def is_experimental_torch_tracing_enabled() -> bool:
+    """
+    Checks if experimental torch tracing is enabled by environment variable NNCF_EXPERIMENTAL_TORCH_TRACING.
+
+    :return: True if experimental torch tracing is enabled, False otherwise.
+    """
+    return os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is not None
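
The flag is read from the environment on every call, so it can be toggled before any NNCF entry point runs; note that only the variable's presence is checked, not its value. A short usage sketch, assuming the variable is not already set:

```python
import os

from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled

assert not is_experimental_torch_tracing_enabled()  # assuming the variable is unset
os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"  # any value enables the feature
assert is_experimental_torch_tracing_enabled()
```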

nncf/experimental/torch2/commands.py (new file, +43)

@@ -0,0 +1,43 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+from torch import nn
+
+from nncf.common.graph.transformations.commands import Command
+from nncf.common.graph.transformations.commands import TransformationType
+from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+
+
+class PT2InsertionCommand(Command):
+    """
+    Insertion operation to the models.
+    """
+
+    def __init__(
+        self,
+        target_points: List[PTTargetPoint],
+        hook_module: nn.Module,
+        *,
+        handle_storage: Optional[List[RemovableHookHandle]] = None,
+    ):
+        """
+        :param target_points: The list of target points for the command.
+        :param hook_module: The hook module for the command that will be inserted into the model
+            to execute at the target points.
+        :param handle_storage: The handle storage for the command to collect RemovableHookHandle. Defaults to None.
+        """
+        super().__init__(TransformationType.INSERT)
+        self.target_points = target_points
+        self.hook_module = hook_module
+        self.handle_storage = handle_storage
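
Because `PT2InsertionCommand` derives from `Command`, it can be registered directly in a `TransformationLayout` (which now accepts plain `Command` objects, per the layout.py change above) and handed to the PT2 model transformer. A sketch with placeholder target-point details; how the MinMax backend actually constructs its target points is not shown here, and the target node name is hypothetical.

```python
# Sketch: the target node name and hook module are placeholders.
import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.torch2.commands import PT2InsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint

target_point = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name="linear/linear_0")
hook = torch.nn.Identity()  # stand-in for a quantizer module

layout = TransformationLayout()
layout.register(PT2InsertionCommand(target_points=[target_point], hook_module=hook))
```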

nncf/experimental/torch2/function_hook/handle_inner_functions.py (+1, -1)

@@ -10,7 +10,7 @@
 # limitations under the License.
 
 """
-This module implements selected functions from the `torch` module, excluding the `hand_function` mechanism.
+This module implements selected functions from the `torch` module, excluding the `handle_torch_function` function.
 
 It processes inner functions to handle exception hooks and graph analysis. The implementation is designed
 to support custom handling of inner function exceptions for specific functions.

nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py (+40, -7)

@@ -19,7 +19,6 @@
 
 import nncf
 import nncf.torch.graph.operator_metatypes as om
-from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.layer_attributes import BaseLayerAttributes
 from nncf.common.graph.layer_attributes import Dtype
@@ -30,6 +29,7 @@
 from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
 from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
 from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
+from nncf.torch.graph.graph import PTNNCFGraph
 
 
 def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str:
@@ -159,14 +159,14 @@ def get_layer_attributes(
     return None
 
 
-def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
+def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> PTNNCFGraph:
     """
-    Converts a graph to an NNCFGraph.
+    Converts a graph to an PTNNCFGraph.
 
     :param nx_graph: The graph to convert.
     :return: The converted NNCFGraph.
     """
-    nncf_graph = NNCFGraph()
+    nncf_graph = PTNNCFGraph()
 
     map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
     for node, data in nx_graph.nodes(data=True):
@@ -178,10 +178,11 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
         meta_type = get_meta_type(node_type, meta)
         layer_attributes = get_layer_attributes(nx_graph, node, meta)
         nncf_node = nncf_graph.add_nncf_node(
+            layer_attributes=layer_attributes,
+            layer_name=node_name,
+            node_metatype=meta_type,
             node_name=node_name,
             node_type=node_type,
-            node_metatype=meta_type,
-            layer_attributes=layer_attributes,
         )
         map_nx_node_to_nncf_node[node] = nncf_node
 
@@ -207,7 +208,7 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
     return nncf_graph
 
 
-def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> NNCFGraph:
+def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> PTNNCFGraph:
     """
     Builds an NNCF graph from the given PyTorch model.
 
@@ -218,3 +219,35 @@ def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> NNCFGraph:
     """
     graph = build_graph(model, *args, **kwargs)
     return convert_to_nncf_graph(graph)
+
+
+class GraphModelWrapper:
+    """
+    A class that wraps a PyTorch model with examples inputs and provides an interface
+    to build a computational graph of the model.
+
+    :param model: The PyTorch model to be wrapped.
+    :param example_input: A tuple of example input for the model.
+    """
+
+    def __init__(self, model: nn.Module, example_input: Any) -> None:
+        """
+        Initialize the GraphModelWrapper.
+        """
+        self.model = model
+        self.example_input = example_input
+
+    def build_graph(self) -> PTNNCFGraph:
+        """
+        Constructs a computational graph of the given model.
+
+        This function builds a directed graph `PTNNCFGraph` representing the operations
+        and data flow within the model by leveraging hooks by using GraphBuilderMode.
+
+        :return: A PTNNCFGraph where nodes represent operations of model.
+        """
+        if isinstance(self.example_input, dict):
+            return build_nncf_graph(self.model, **self.example_input)
+        if isinstance(self.example_input, tuple):
+            return build_nncf_graph(self.model, *self.example_input)
+        return build_nncf_graph(self.model, self.example_input)
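
`build_graph` dispatches on the type of `example_input`: a dict is unpacked as keyword arguments, a tuple as positional arguments, and anything else is passed as a single argument. A small usage sketch with a toy module; whether the module must additionally be wrapped by the torch2 function-hook machinery before tracing is not shown in this hunk.

```python
# Illustrative usage of GraphModelWrapper; the toy model is a placeholder.
import torch

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

toy_model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
wrapper = GraphModelWrapper(toy_model, example_input=torch.ones(1, 8))  # single-tensor input

nncf_graph = wrapper.build_graph()  # PTNNCFGraph of the traced operations
print(len(nncf_graph.get_all_nodes()))
```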
