9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
- from typing import TypeVar , cast
12
+ from typing import Any , TypeVar , cast
13
13
14
14
import nncf
15
15
from nncf .common .engine import Engine
20
20
from nncf .common .utils .backend import BackendType
21
21
from nncf .common .utils .backend import get_backend
22
22
from nncf .data .dataset import Dataset
23
+ from nncf .experimental .common .check_feature import is_experimental_torch_tracing_enabled
23
24
24
25
TModel = TypeVar ("TModel" )
25
26
@@ -53,17 +54,22 @@ def create(model: TModel) -> NNCFGraph:
53
54
54
55
return FXGraphConverter .create_nncf_graph (cast (GraphModule , model ))
55
56
if model_backend == BackendType .TORCH :
57
+ from nncf .experimental .torch2 .function_hook .nncf_graph .nncf_graph_builder import GraphModelWrapper
56
58
from nncf .torch .nncf_network import NNCFNetwork
57
59
58
- return cast (NNCFNetwork , model ).nncf .get_graph ()
60
+ if isinstance (model , GraphModelWrapper ):
61
+ return model .build_graph ()
62
+ if isinstance (model , NNCFNetwork ):
63
+ return model .nncf .get_graph ()
64
+ raise nncf .InternalError (f"Unexpected type of model { type (model )} for TORCH backend" )
59
65
raise nncf .UnsupportedBackendError (
60
- "Cannot create backend-specific graph because {} is not supported!" . format ( model_backend . value )
66
+ f "Cannot create backend-specific graph because { model_backend . value } is not supported!"
61
67
)
62
68
63
69
64
70
class ModelTransformerFactory :
65
71
@staticmethod
66
- def create (model : TModel , inplace : bool = False ) -> ModelTransformer :
72
+ def create (model : TModel , inplace : bool = False ) -> ModelTransformer [ Any ] :
67
73
"""
68
74
Factory method to create backend-specific ModelTransformer instance based on the input model.
69
75
@@ -84,11 +90,18 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
84
90
from nncf .openvino .graph .model_transformer import OVModelTransformer
85
91
86
92
return OVModelTransformer (cast (Model , model ), inplace = inplace )
87
- if model_backend == BackendType .TORCH :
93
+ if model_backend == BackendType .TORCH and is_experimental_torch_tracing_enabled ():
94
+ from nncf .experimental .torch2 .function_hook .nncf_graph .nncf_graph_builder import GraphModelWrapper
95
+ from nncf .experimental .torch2 .model_transformer import PT2ModelTransformer
96
+
97
+ return PT2ModelTransformer (cast (GraphModelWrapper , model ))
98
+
99
+ if model_backend == BackendType .TORCH and not is_experimental_torch_tracing_enabled ():
88
100
from nncf .torch .model_transformer import PTModelTransformer
89
101
from nncf .torch .nncf_network import NNCFNetwork
90
102
91
103
return PTModelTransformer (cast (NNCFNetwork , model ))
104
+
92
105
if model_backend == BackendType .TORCH_FX :
93
106
from torch .fx import GraphModule
94
107
@@ -125,11 +138,16 @@ def create(model: TModel) -> Engine:
125
138
if model_backend in (BackendType .TORCH , BackendType .TORCH_FX ):
126
139
from torch .nn import Module
127
140
141
+ from nncf .experimental .torch2 .function_hook .nncf_graph .nncf_graph_builder import GraphModelWrapper
128
142
from nncf .torch .engine import PTEngine
129
143
130
- return PTEngine (cast (Module , model ))
144
+ if isinstance (model , GraphModelWrapper ):
145
+ pt_model = model .model
146
+ else :
147
+ pt_model = cast (Module , model )
148
+ return PTEngine (pt_model )
131
149
raise nncf .UnsupportedBackendError (
132
- "Cannot create backend-specific engine because {} is not supported!" . format ( model_backend . value )
150
+ f "Cannot create backend-specific engine because { model_backend . value } is not supported!"
133
151
)
134
152
135
153
@@ -176,10 +194,14 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
176
194
from nncf .openvino .statistics .aggregator import OVStatisticsAggregator
177
195
178
196
return OVStatisticsAggregator (dataset )
179
- if model_backend == BackendType .TORCH :
197
+ if model_backend == BackendType .TORCH and not is_experimental_torch_tracing_enabled () :
180
198
from nncf .torch .statistics .aggregator import PTStatisticsAggregator
181
199
182
200
return PTStatisticsAggregator (dataset )
201
+ if model_backend == BackendType .TORCH and is_experimental_torch_tracing_enabled ():
202
+ from nncf .experimental .torch2 .statistics .aggregator import PT2StatisticsAggregator
203
+
204
+ return PT2StatisticsAggregator (dataset )
183
205
if model_backend == BackendType .TORCH_FX :
184
206
from nncf .experimental .torch .fx .statistics .aggregator import FXStatisticsAggregator
185
207
0 commit comments