9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
- from typing import Callable , Dict , Iterable , List , Optional , Tuple
12
+ from typing import Callable , Dict , Iterable , List , Optional , Tuple , Union
13
13
14
14
import torch
15
15
23
23
from nncf .common .graph .transformations .commands import TargetType
24
24
from nncf .common .graph .transformations .layout import TransformationLayout
25
25
from nncf .common .tensor_statistics .statistic_point import StatisticPoint
26
+ from nncf .experimental .common .check_feature import is_experimental_torch_tracing_enabled
26
27
from nncf .experimental .common .tensor_statistics .collectors import MaxVarianceReducer
27
28
from nncf .experimental .common .tensor_statistics .collectors import MeanAbsMaxReducer
28
29
from nncf .experimental .common .tensor_statistics .collectors import MeanAggregator
35
36
from nncf .experimental .common .tensor_statistics .statistics import MeanMagnitudeTensorStatistic
36
37
from nncf .experimental .common .tensor_statistics .statistics import MeanVarianceTensorStatistic
37
38
from nncf .experimental .common .tensor_statistics .statistics import WCTensorStatistic
39
+ from nncf .experimental .torch2 .commands import PT2InsertionCommand
40
+ from nncf .experimental .torch2 .function_hook .nncf_graph .nncf_graph_builder import GraphModelWrapper
41
+ from nncf .experimental .torch2 .model_transformer import PT2ModelTransformer
38
42
from nncf .parameters import CompressWeightsMode
39
43
from nncf .quantization .algorithms .smooth_quant .torch_backend import SQMultiply
40
44
from nncf .quantization .algorithms .weight_compression .awq_patterns import get_awq_patterns
46
50
from nncf .quantization .algorithms .weight_compression .weight_lowering import compress_weight
47
51
from nncf .tensor import Tensor
48
52
from nncf .tensor .definitions import TensorDataType
49
- from nncf .torch .dynamic_graph .scope import Scope
50
53
from nncf .torch .graph import operator_metatypes as om
51
54
from nncf .torch .graph .operator_metatypes import PTMulMetatype
52
55
from nncf .torch .graph .pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
@@ -186,8 +189,14 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
186
189
return activation_ports [0 ]
187
190
188
191
def get_weight (
189
- self , node_with_weight : NNCFNode , weight_port_id : int , model : torch .nn .Module , graph : NNCFGraph
192
+ self ,
193
+ node_with_weight : NNCFNode ,
194
+ weight_port_id : int ,
195
+ model : Union [GraphModelWrapper , torch .nn .Module ],
196
+ graph : NNCFGraph ,
190
197
) -> Tensor :
198
+ if isinstance (model , GraphModelWrapper ):
199
+ model = model .model
191
200
weight_node = get_const_node (node_with_weight , weight_port_id , graph )
192
201
weight_name = weight_node .layer_attributes .name
193
202
weight = get_const_data (weight_node , model )
@@ -197,7 +206,11 @@ def get_weight(
197
206
return Tensor (weight )
198
207
199
208
def get_weight_dtype(
    self,
    node_with_weight: NNCFNode,
    weight_port_id: int,
    model: Union[GraphModelWrapper, torch.nn.Module],
    graph: NNCFGraph,
) -> TensorDataType:
    """
    Report the data type of the weight attached to the given node.

    The weight is fetched through `self.get_weight` with the same arguments,
    and the `dtype` of the returned tensor is handed back unchanged.

    :param node_with_weight: Node whose weight is inspected.
    :param weight_port_id: Port id forwarded to `get_weight` to select the weight.
    :param model: Model forwarded to `get_weight`; either a GraphModelWrapper
        or a plain torch.nn.Module.
    :param graph: Graph forwarded to `get_weight`.
    :return: The `dtype` attribute of the tensor returned by `get_weight`.
    """
    weight_tensor = self.get_weight(node_with_weight, weight_port_id, model, graph)
    return weight_tensor.dtype
203
216
@@ -209,7 +222,14 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
209
222
def set_weight(
    self,
    node_with_weight: NNCFNode,
    weight_port_id: int,
    model: Union[GraphModelWrapper, torch.nn.Module],
    graph: NNCFGraph,
    weight: Tensor,
) -> None:
    """
    Overwrite the stored weight of a node with the values of `weight`.

    When experimental torch tracing is enabled, the constant node feeding the
    weight port is looked up in `graph`, its name is split into a module name
    and an attribute name, and the `.data` of that attribute is replaced.
    Otherwise the update is delegated to `update_parameter`.

    :param node_with_weight: Node whose weight is replaced.
    :param weight_port_id: Port id used to locate the constant node of the weight.
    :param model: Model that owns the weight; either a GraphModelWrapper or a
        plain torch.nn.Module.
    :param graph: Graph used to find the constant node for the weight port.
    :param weight: Tensor whose `.data` becomes the new weight value.
    """
    if not is_experimental_torch_tracing_enabled():
        # Legacy path: the parameter is addressed by the node name and the
        # fixed attribute name "weight".
        update_parameter(node_with_weight.node_name, "weight", weight.data, model)
        return

    # Unwrap only when a wrapper was actually passed. Dereferencing `.model`
    # unconditionally would fail on a plain module, or silently pick a
    # submodule that happens to be called `model`.
    torch_model = model.model if isinstance(model, GraphModelWrapper) else model

    weight_node = get_const_node(node_with_weight, weight_port_id, graph)
    module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
    owner_module = get_module_by_name(module_name, torch_model)
    weight_param = getattr(owner_module, weight_attr_name)
    # Assign through `.data` so the existing attribute object is kept and only
    # its underlying values are swapped.
    weight_param.data = weight.data
213
233
214
234
def insert_adapters (
215
235
self , wc_params : WeightCompressionParameters , lora_A : Tensor , lora_B : Tensor , int8_lora : bool
@@ -229,13 +249,19 @@ def filter_func(point: StatisticPoint) -> bool:
229
249
230
250
def transform_model (
231
251
self ,
232
- model : NNCFNetwork ,
252
+ model : Union [ GraphModelWrapper , torch . nn . Module ] ,
233
253
graph : NNCFGraph ,
234
254
weight_compression_parameters : Iterable [WeightCompressionParameters ],
235
255
precomputed_scales : Dict [str , Tensor ] = None ,
236
256
precomputed_zero_points : Dict [str , Tensor ] = None ,
237
257
lora_correction_algo : LoraCorrectionAlgorithm = None ,
238
258
) -> NNCFNetwork :
259
+ if isinstance (model , GraphModelWrapper ):
260
+ model_transformer = PT2ModelTransformer (model )
261
+ model = model .model
262
+ else :
263
+ model_transformer = PTModelTransformer (model )
264
+
239
265
transformation_layout = TransformationLayout ()
240
266
241
267
for wc_params in weight_compression_parameters :
@@ -291,38 +317,43 @@ def transform_model(
291
317
292
318
# sets compressed tensor
293
319
# TODO:(AlexanderDokuchaev): update set_const_data
294
- compressed_parameter = torch .nn .Parameter (packed_tensor , requires_grad = False )
295
320
module_name , weight_attr_name = split_const_name (weight_name )
296
321
module = get_module_by_name (module_name , model )
297
322
weight = getattr (module , weight_attr_name )
323
+
298
324
if not isinstance (weight , torch .nn .Parameter ):
299
325
msg = f"Weight is not a torch.nn.Parameter in the model by name { weight_name } ."
300
326
raise nncf .InternalError (msg )
301
327
302
- setattr (module , weight_attr_name , compressed_parameter )
303
-
304
- consumer_nodes = graph .get_next_nodes (weight_node )
305
- if len (consumer_nodes ) > 1 :
306
- for c_node in consumer_nodes :
307
- c_module = model .nncf .get_module_by_scope (Scope .from_str (c_node .layer_name ))
308
- for name , param in c_module .named_parameters (recurse = False , remove_duplicate = False ):
309
- if id (param ) == id (weight ):
310
- setattr (c_module , name , compressed_parameter )
311
-
312
- # registry weight decompression module in the model
313
- decompressor_name = f"weights_decompressor_{ weight_node .node_name .replace ('.' , '_' )} "
314
-
315
- # inserts the weight decompressor into the model as the post hook on the model weight
316
- transformation_layout .register (
317
- PTSharedFnInsertionCommand (
318
- [PTTargetPoint (TargetType .OPERATOR_POST_HOOK , target_node_name = weight_node .node_name )],
319
- decompressor ,
320
- decompressor_name ,
328
+ weight .requires_grad = False
329
+ weight .data = packed_tensor
330
+
331
+ if is_experimental_torch_tracing_enabled ():
332
+ transformation_layout .register (
333
+ PT2InsertionCommand (
334
+ [
335
+ PTTargetPoint (
336
+ TargetType .OPERATOR_POST_HOOK , target_node_name = weight_node .node_name .replace ("." , ":" )
337
+ )
338
+ ],
339
+ decompressor ,
340
+ )
341
+ )
342
+ else :
343
+ # registry weight decompression module in the model
344
+ decompressor_name = f"weights_decompressor_{ weight_node .node_name .replace ('.' , '_' )} "
345
+
346
+ # inserts the weight decompressor into the model as the post hook on the model weight
347
+ transformation_layout .register (
348
+ PTSharedFnInsertionCommand (
349
+ [PTTargetPoint (TargetType .OPERATOR_POST_HOOK , target_node_name = weight_node .node_name )],
350
+ decompressor ,
351
+ decompressor_name ,
352
+ )
321
353
)
322
- )
323
354
324
355
# apply transformations
325
- transformed_model = PTModelTransformer ( model ) .transform (transformation_layout )
356
+ transformed_model = model_transformer .transform (transformation_layout )
326
357
327
358
return transformed_model
328
359
@@ -356,6 +387,9 @@ def scale_insertion_command(
356
387
357
388
sq_multiply = SQMultiply (scale .shape )
358
389
sq_multiply .scale = scale
390
+
391
+ if is_experimental_torch_tracing_enabled ():
392
+ return PT2InsertionCommand (target_points , sq_multiply )
359
393
scale_node_name = f"{ source_node .node_name } /awq_mul"
360
394
return PTSharedFnInsertionCommand (target_points , sq_multiply , scale_node_name )
361
395
0 commit comments