11
11
12
12
from copy import deepcopy
13
13
from dataclasses import dataclass
14
- from typing import Any , Dict , List , Optional , TypeVar
14
+ from typing import Dict , List , Optional , TypeVar
15
15
16
16
import nncf
17
- from nncf import Dataset
18
17
from nncf import nncf_logger
19
18
from nncf .common .factory import ModelTransformerFactory
20
19
from nncf .common .graph .graph import NNCFGraph
29
28
from nncf .parameters import CompressWeightsMode
30
29
from nncf .quantization .algorithms .algorithm import Algorithm
31
30
from nncf .quantization .algorithms .weight_compression .activation_stats import process_stats
31
+ from nncf .quantization .algorithms .weight_compression .backend import WeightCompressionAlgoBackend
32
32
from nncf .quantization .algorithms .weight_compression .config import WeightCompressionParameters
33
33
from nncf .quantization .algorithms .weight_compression .weight_lowering import calculate_nf4_scale
34
34
from nncf .quantization .algorithms .weight_compression .weight_lowering import do_nf4_dequantization
@@ -61,34 +61,20 @@ class AWQ(Algorithm):
61
61
62
62
def __init__ (
63
63
self ,
64
- model : TModel ,
65
- name_to_node_mapping : Dict [str , Any ],
66
- all_weight_params : List [WeightCompressionParameters ],
67
- nodes_to_compress : List [NNCFNode ],
68
- statistics : Dict [str , WCTensorStatistic ],
69
64
subset_size : int = 32 ,
70
- percent_to_apply = 0.002 ,
71
- alpha_min = 0.0 ,
72
- alpha_max = 1.0 ,
73
- steps = 100 ,
65
+ percent_to_apply : float = 0.002 ,
66
+ alpha_min : float = 0.0 ,
67
+ alpha_max : float = 1.0 ,
68
+ steps : int = 100 ,
74
69
):
75
70
"""
76
- :param model: Model for applying algorithm.
77
- :param name_to_node_mapping: Name to node mapping for updating node weights.
78
- :param all_weight_params: List of all weight parameters.
79
- :param nodes_to_compress: List of nodes for processing.
80
- :param statistics: Input activation statistics for each node.
81
71
:param subset_size: The number of samples for AWQ.
82
72
:param percent_to_apply: The percent of outliers for correction.
83
73
:param alpha_min: Minimum value of smoothness parameter for grid search.
84
74
:param alpha_max: Maximal value of smoothness parameter for grid search.
85
75
:param steps: The number of the steps in grid search.
86
76
"""
87
77
super ().__init__ ()
88
- self .name_to_node_mapping = name_to_node_mapping
89
- self ._all_weight_params = all_weight_params
90
- self ._nodes_to_compress = nodes_to_compress
91
- self ._statistics = statistics
92
78
self ._subset_size = subset_size
93
79
self ._percent_to_apply = percent_to_apply
94
80
self ._alpha_min = alpha_min
@@ -98,44 +84,54 @@ def __init__(
98
84
self ._patterns = None
99
85
self ._scale_per_target_node = {}
100
86
101
- self ._set_backend_entity (model )
102
-
103
87
@property
def available_backends(self) -> List[BackendType]:
    """
    Lists the backends this algorithm can be applied to.

    :return: Supported backend types: OpenVINO and Torch.
    """
    supported_backends = [BackendType.OPENVINO, BackendType.TORCH]
    return supported_backends
106
90
107
def _set_backend_entity(
    self, model: TModel, wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None
) -> None:
    """
    Creates a helper class with backend-specific logic of the algorithm
    and stores the AWQ patterns it provides.

    :param model: Backend-specific input model.
    :param wc_backend_entity: Weight compression algorithm backend. It is required for
        OpenVINO models, where its `name_to_node_mapping` is passed to the AWQ backend;
        it is not read for Torch models.
    :raises ValueError: If the model is an OpenVINO model and `wc_backend_entity` is None.
    :raises nncf.UnsupportedBackendError: If the model backend is neither OpenVINO nor Torch.
    """
    model_backend = get_backend(model)
    if model_backend == BackendType.OPENVINO:
        # The parameter defaults to None, so fail with an explicit message here instead of
        # an AttributeError on `None.name_to_node_mapping` below.
        if wc_backend_entity is None:
            msg = "wc_backend_entity must be provided to create the AWQ backend entity for the OpenVINO backend."
            raise ValueError(msg)

        # Imported lazily so the OpenVINO backend module is only loaded for OpenVINO models.
        from nncf.quantization.algorithms.weight_compression.openvino_backend import OVAWQAlgoAlgoBackend

        self._backend_entity = OVAWQAlgoAlgoBackend(model, wc_backend_entity.name_to_node_mapping)
    elif model_backend == BackendType.TORCH:
        # Imported lazily so the Torch backend module is only loaded for Torch models.
        from nncf.quantization.algorithms.weight_compression.torch_backend import PTAWQAlgoAlgoBackend

        self._backend_entity = PTAWQAlgoAlgoBackend()
    else:
        msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!"
        raise nncf.UnsupportedBackendError(msg)
    # Both supported backends expose their AWQ patterns through the same call.
    self._patterns = self._backend_entity.get_awq_patterns()
122
114
123
115
def apply (
124
116
self ,
125
117
model : TModel ,
126
118
graph : NNCFGraph ,
127
- statistic_points : Optional [StatisticPointsContainer ] = None ,
128
- dataset : Optional [Dataset ] = None ,
119
+ all_weight_params : List [WeightCompressionParameters ],
120
+ nodes_to_compress : List [NNCFNode ],
121
+ statistics : Dict [str , WCTensorStatistic ],
122
+ wc_backend_entity : Optional [WeightCompressionAlgoBackend ] = None ,
129
123
) -> TModel :
130
124
"""
131
125
Applies the algorithm to the model.
132
-
133
126
:param model: Model for applying algorithm.
134
127
:param graph: Model graph.
135
- :param statistic_points: Statistic points with collected statistics values.
136
- :param dataset: A representative dataset for the calibration process.
128
+ :param all_weight_params: List of all weight parameters.
129
+ :param nodes_to_compress: List of nodes for processing.
130
+ :param statistics: Input activation statistics for each node.
131
+ :param wc_backend_entity: Weight compression algorithm backend.
137
132
:return: A resulting model.
138
133
"""
134
+ self ._set_backend_entity (model , wc_backend_entity )
139
135
matches = []
140
136
141
137
inference_nncf_graph = transform_to_inference_graph (deepcopy (graph ), [], [], [], [])
@@ -151,7 +147,7 @@ def apply(
151
147
model_transformer = ModelTransformerFactory .create (model , inplace = True )
152
148
153
149
awq_data = {}
154
- name_mapping = {wp .weight_name : idx for idx , wp in enumerate (self . _all_weight_params )}
150
+ name_mapping = {wp .weight_name : idx for idx , wp in enumerate (all_weight_params )}
155
151
156
152
for match in matches :
157
153
nncf_node = graph .get_node_by_key (match [- 1 ])
@@ -166,11 +162,11 @@ def apply(
166
162
if target_node_names [- 1 ] not in name_mapping :
167
163
continue
168
164
169
- weight_params = self . _all_weight_params [name_mapping [target_node_names [- 1 ]]]
165
+ weight_params = all_weight_params [name_mapping [target_node_names [- 1 ]]]
170
166
171
167
if weight_params .compression_config .num_bits != 4 :
172
168
continue
173
- target_node = self . _nodes_to_compress [name_mapping [target_node_names [- 1 ]]]
169
+ target_node = nodes_to_compress [name_mapping [target_node_names [- 1 ]]]
174
170
175
171
# avoid matching different patterns for the same node
176
172
if target_node .node_name in awq_data :
@@ -182,7 +178,7 @@ def apply(
182
178
merge_node_names = []
183
179
for weight_op_friendly_name , _ in self ._backend_entity .get_weight_names_and_port_ids (nncf_node , graph ):
184
180
merge_node_names .append (weight_op_friendly_name )
185
- merge_node = self . _nodes_to_compress [name_mapping [merge_node_names [- 1 ]]]
181
+ merge_node = nodes_to_compress [name_mapping [merge_node_names [- 1 ]]]
186
182
else : # pattern Act->MatMul or Act->Multiply->MatMul
187
183
merge_node = nncf_node
188
184
@@ -204,7 +200,7 @@ def apply(
204
200
205
201
config = wp .compression_config
206
202
207
- s , X = process_stats (self . _statistics [k ], self ._subset_size )
203
+ s , X = process_stats (statistics [k ], self ._subset_size )
208
204
209
205
top_k = max (int (s .shape [0 ] * self ._percent_to_apply ), 1 )
210
206
topk_idxs = fns .argsort (- s )[:top_k ]
0 commit comments