10
10
# limitations under the License.
11
11
12
12
13
- from typing import List
14
-
15
13
import numpy as np
16
14
import torch
17
15
from torch .quantization .fake_quantize import FakeQuantize
18
16
19
17
import nncf
20
- from nncf .common .graph .transformations .commands import Command
21
18
from nncf .common .graph .transformations .commands import TargetType
22
19
from nncf .common .graph .transformations .layout import TransformationLayout
23
20
from nncf .parameters import StripFormat
@@ -191,10 +188,8 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF
191
188
:param strip format: Describes the format in which model is saved after strip.
192
189
:return: The modified NNCF network.
193
190
"""
194
- model_layout = model .nncf .transformation_layout ()
195
- transformations = model_layout .transformations
196
191
if strip_format == StripFormat .DQ :
197
- model = replace_with_decompressors (model , transformations )
192
+ model = replace_with_decompressors (model )
198
193
elif strip_format == StripFormat .NATIVE :
199
194
model = replace_quantizer_to_torch_native_module (model )
200
195
model = remove_disabled_quantizers (model )
@@ -204,7 +199,7 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF
204
199
return model
205
200
206
201
207
- def replace_with_decompressors (model : NNCFNetwork , transformations : List [ Command ] ) -> NNCFNetwork :
202
+ def replace_with_decompressors (model : NNCFNetwork ) -> NNCFNetwork :
208
203
"""
209
204
Performs transformation from fake quantize format (FQ) to dequantization one (DQ).
210
205
The former takes floating-point input, quantizes and dequantizes, and returns a floating-point value,
@@ -222,21 +217,21 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command
222
217
:return: The modified NNCF network.
223
218
"""
224
219
transformation_layout = TransformationLayout ()
220
+ transformations = model .nncf .transformation_layout ().transformations
225
221
model = model .nncf .get_clean_shallow_copy ()
226
222
graph = model .nncf .get_graph ()
227
-
228
223
for command in transformations :
229
224
quantizer = command .fn
230
225
231
- msg = None
226
+ msg = ""
232
227
if not isinstance (quantizer , (SymmetricQuantizer , AsymmetricQuantizer )):
233
- msg = f"Unexpected compression module on strip: { quantizer .__class__ } "
234
- elif quantizer ._qspec .half_range or quantizer ._qspec .narrow_range :
235
- msg = "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False."
236
- elif quantizer .num_bits not in [4 , 8 ]:
237
- msg = f"Unsupported number of bits { quantizer .num_bits } for the quantizer { quantizer } "
238
- elif len (command .target_points ) > 1 :
239
- msg = "Command contains more than one target point! "
228
+ msg = f"Unexpected compression module on strip: { quantizer .__class__ } . \n "
229
+ if quantizer ._qspec .half_range or quantizer ._qspec .narrow_range :
230
+ msg + = "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n "
231
+ if quantizer .num_bits not in [4 , 8 ]:
232
+ msg + = f"Unsupported number of bits { quantizer .num_bits } for the quantizer { quantizer } . \n "
233
+ if len (command .target_points ) > 1 :
234
+ msg + = "Command contains more than one target point. "
240
235
if msg :
241
236
raise nncf .ValidationError (msg )
242
237
0 commit comments