10
10
# limitations under the License.
11
11
12
12
from copy import deepcopy
13
- from typing import Any , Dict , List , Optional , Tuple , TypeVar
13
+ from typing import Dict , List , Optional , Tuple , TypeVar
14
14
15
15
import nncf
16
- from nncf import Dataset
17
16
from nncf .common .graph .graph import NNCFGraph
18
- from nncf .common .graph .graph import NNCFNode
19
17
from nncf .common .logging .track_progress import track
20
- from nncf .common .tensor_statistics .statistic_point import StatisticPointsContainer
21
18
from nncf .common .utils .backend import BackendType
22
19
from nncf .common .utils .backend import get_backend
23
20
from nncf .experimental .common .tensor_statistics .statistics import WCTensorStatistic
24
21
from nncf .parameters import CompressWeightsMode
25
22
from nncf .quantization .algorithms .weight_compression .activation_stats import process_stats
23
+ from nncf .quantization .algorithms .weight_compression .backend import WeightCompressionAlgoBackend
26
24
from nncf .quantization .algorithms .weight_compression .config import WeightCompressionConfig
27
25
from nncf .quantization .algorithms .weight_compression .config import WeightCompressionParameters
28
26
from nncf .quantization .algorithms .weight_compression .weight_lowering import calculate_normalized_weight_and_fp4_scale
@@ -45,70 +43,57 @@ class ScaleEstimation:
45
43
46
44
def __init__ (
47
45
self ,
48
- model : TModel ,
49
- name_to_node_mapping : Dict [str , Any ],
50
- all_weight_params : List [WeightCompressionParameters ],
51
- nodes_to_compress : List [NNCFNode ],
52
- statistics : Dict [str , WCTensorStatistic ],
53
46
subset_size : int = 32 ,
54
47
initial_steps : int = 5 ,
55
48
scale_steps : int = 10 ,
56
49
weight_penalty : float = - 1.0 ,
57
50
):
58
51
"""
59
- :param model: Model for applying algorithm.
60
- :param name_to_node_mapping: Name to node mapping for updating node weights.
61
- :param all_weight_params: List of all weight parameters.
62
- :param nodes_to_compress: List of nodes for processing.
63
- :param statistics: Input activation statistics for each node.
64
52
:param subset_size: The number of samples for scale estimation.
65
53
:param initial_steps: The number of the steps for absmax scale rectification.
66
54
:param scale_steps: The number of the steps for grid search scale rectification
67
55
from 1.0 to 1.0 - 0.05 * scale_steps.
68
56
:param weight_penalty: Coefficient of the penalty between fp and compressed weights. If -1, the penalty is not applied.
69
57
"""
70
58
super ().__init__ ()
71
- self .name_to_node_mapping = name_to_node_mapping
72
- self ._all_weight_params = all_weight_params
73
- self ._nodes_to_compress = nodes_to_compress
74
- self ._statistics = statistics
75
59
self ._subset_size = subset_size
76
60
self ._initial_steps = initial_steps
77
61
self ._scale_steps = scale_steps
78
62
self ._weight_penalty = weight_penalty
79
63
80
- self ._set_backend_entity (model )
81
-
82
64
@property
83
65
def available_backends (self ) -> List [BackendType ]:
84
- return [BackendType .OPENVINO ]
66
+ return [BackendType .OPENVINO , BackendType . TORCH ]
85
67
86
68
def _set_backend_entity (self , model : TModel ) -> None :
87
69
"""
88
70
Creates a helper class with backend-specific logic of the algorithm.
89
71
90
72
:param model: Backend-specific input model.
91
- :param all_weight_params: List of all weight parameters.
92
- :param nodes_to_compress: List of nodes for processing.
93
- :param activations: The input activations of the layers considered for compression.
94
73
"""
95
-
96
74
model_backend = get_backend (model )
97
75
if model_backend == BackendType .OPENVINO :
98
76
from nncf .quantization .algorithms .weight_compression .openvino_backend import OVWeightCompressionAlgoBackend
99
77
100
- self ._backend_entity = OVWeightCompressionAlgoBackend (model , self .name_to_node_mapping )
78
+ self ._backend_entity = OVWeightCompressionAlgoBackend (model )
79
+ elif model_backend == BackendType .TORCH :
80
+ from nncf .quantization .algorithms .weight_compression .torch_backend import PTWeightCompressionAlgoBackend
81
+
82
+ self ._backend_entity = PTWeightCompressionAlgoBackend ()
101
83
else :
102
84
raise nncf .UnsupportedBackendError (
103
- "Cannot return backend-specific AWQ entity because {} is not supported!" .format (model_backend .value )
85
+ "Cannot return backend-specific Scale Estimation entity because {} is not supported!" .format (
86
+ model_backend .value
87
+ )
104
88
)
105
89
106
90
def apply (
107
91
self ,
108
92
model : TModel ,
109
93
graph : NNCFGraph ,
110
- statistic_points : Optional [StatisticPointsContainer ] = None ,
111
- dataset : Optional [Dataset ] = None ,
94
+ all_weight_params : List [WeightCompressionParameters ],
95
+ statistics : Dict [str , WCTensorStatistic ],
96
+ backend_entity : Optional [WeightCompressionAlgoBackend ] = None ,
112
97
) -> Tuple [Dict [str , Tensor ], Dict [str , Tensor ]]:
113
98
"""
114
99
Estimates better scale for the int4 nodes in the model.
@@ -119,23 +104,28 @@ def apply(
119
104
120
105
:param model: Model for applying algorithm.
121
106
:param graph: Model graph.
107
+ :param all_weight_params: List of all weight parameters.
108
+ :param statistics: Input activation statistics for each node.
122
109
:param statistic_points: Statistic points with collected statistics values.
123
110
:param dataset: A representative dataset for the calibration process.
111
+ :param backend_entity: Weight compression algorithm backend.
124
112
:return: Two dictionaries for estimated scales and zero points for each weight name.
125
113
"""
126
-
114
+ self ._backend_entity = backend_entity
115
+ if self ._backend_entity is None :
116
+ self ._set_backend_entity (model )
127
117
scales , zero_points = dict (), dict ()
128
118
129
- for wp in track (self . _all_weight_params , description = "Applying Scale Estimation" ):
119
+ for wp in track (all_weight_params , description = "Applying Scale Estimation" ):
130
120
weight_name = wp .weight_name
131
121
node_name = wp .node_with_weight .node_name
132
122
config = wp .compression_config
133
123
134
- if config .num_bits != 4 or node_name not in self . _statistics :
124
+ if config .num_bits != 4 or node_name not in statistics :
135
125
scales [weight_name ] = None
136
126
continue
137
127
138
- stats = self . _statistics [node_name ]
128
+ stats = statistics [node_name ]
139
129
140
130
weight_data = self ._backend_entity .get_weight_names_and_port_ids (wp .node_with_weight , graph )
141
131
if len (weight_data ) != 1 : # not supported by the algorithm
0 commit comments