11
11
12
12
from typing import Any , Callable , Dict , List , Tuple
13
13
14
- import numpy as np
15
14
import torch
16
15
17
16
import nncf .torch .graph .operator_metatypes as om
21
20
from nncf .common .graph .transformations .commands import TargetType
22
21
from nncf .common .quantization .quantizer_propagation .structs import QuantizationTrait
23
22
from nncf .common .tensor_statistics .statistic_point import StatisticPoint
23
+ from nncf .experimental .common .check_feature import is_experimental_torch_tracing_enabled
24
24
from nncf .experimental .common .tensor_statistics .collectors import AbsMaxReducer
25
25
from nncf .experimental .common .tensor_statistics .collectors import MaxAggregator
26
26
from nncf .experimental .common .tensor_statistics .collectors import TensorCollector
27
+ from nncf .experimental .torch2 .commands import PT2ConstUpdateCommand
28
+ from nncf .experimental .torch2 .commands import PT2InsertionCommand
29
+ from nncf .experimental .torch2 .function_hook .nncf_graph .nncf_graph_builder import GraphModelWrapper
27
30
from nncf .quantization .algorithms .smooth_quant .backend import SmoothQuantAlgoBackend
28
31
from nncf .tensor import Tensor
29
32
from nncf .torch .graph .transformations .command_creation import create_command_to_update_weight
@@ -119,6 +122,9 @@ def get_abs_max_channel_collector(
119
122
120
123
@staticmethod
121
124
def get_weight_value (node_with_weight : NNCFNode , model : NNCFNetwork , nncf_graph : NNCFGraph ) -> Tensor :
125
+ if isinstance (model , GraphModelWrapper ):
126
+ model = model .model
127
+
122
128
weight_node = get_const_node (node_with_weight , node_with_weight .metatype .weight_port_ids [0 ], nncf_graph )
123
129
if weight_node is None :
124
130
msg = f"{ node_with_weight } node has no weight node."
@@ -127,7 +133,12 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
127
133
return Tensor (weight_data )
128
134
129
135
@staticmethod
130
- def weight_update_command (node_with_weight : NNCFNode , weight_value : np .ndarray ) -> PTWeightUpdateCommand :
136
+ def weight_update_command (
137
+ node_with_weight : NNCFNode , nncf_graph : NNCFGraph , weight_value : torch .Tensor
138
+ ) -> PTWeightUpdateCommand :
139
+ if is_experimental_torch_tracing_enabled ():
140
+ weight_node = get_const_node (node_with_weight , node_with_weight .metatype .weight_port_ids [0 ], nncf_graph )
141
+ return PT2ConstUpdateCommand (weight_node , weight_value )
131
142
return create_command_to_update_weight (node_with_weight , weight_value )
132
143
133
144
@staticmethod
@@ -145,6 +156,9 @@ def scale_insertion_command(
145
156
146
157
sq_multiply = SQMultiply (scale_value .shape )
147
158
sq_multiply .scale = scale_value
159
+
160
+ if is_experimental_torch_tracing_enabled ():
161
+ return PT2InsertionCommand (target_points = target_points , hook_module = sq_multiply )
148
162
return PTSharedFnInsertionCommand (target_points , sq_multiply , scale_node_name )
149
163
150
164
@staticmethod
@@ -161,6 +175,10 @@ def get_weight_channel_axis(node: NNCFNode) -> int:
161
175
162
176
@staticmethod
163
177
def is_node_with_shared_weight (node : NNCFNode , nncf_graph : NNCFGraph ) -> bool :
178
+ if is_experimental_torch_tracing_enabled ():
179
+ weight_node = get_const_node (node , node .metatype .weight_port_ids [0 ], nncf_graph )
180
+ output_edges = nncf_graph .get_next_nodes (weight_node )
181
+ return len (output_edges ) > 1
164
182
return node .is_shared ()
165
183
166
184
@staticmethod
0 commit comments