Skip to content

Commit 802376f

Browse files
committed
minor correction
1 parent 00d1ced commit 802376f

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

nncf/parameters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class StripFormat(StrEnum):
126126
operations added during the compression process, resulting in a clean model ready for deployment.
127127
The functionality of the model object is still preserved as a compressed model.
128128
129-
:param NATIVE: Returns the model with as much custom NNCF additions as possible,
129+
:param NATIVE: Returns the model with as many custom NNCF additions as possible.
130130
:param DQ: Replaces FakeQuantize operations with dequantization subgraph and compressed weights in low-bit
131131
precision using fake quantize parameters. This is the default format for deployment of models with compressed
132132
weights.

nncf/torch/quantization/strip.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,11 @@
1010
# limitations under the License.
1111

1212

13-
from typing import List
14-
1513
import numpy as np
1614
import torch
1715
from torch.quantization.fake_quantize import FakeQuantize
1816

1917
import nncf
20-
from nncf.common.graph.transformations.commands import Command
2118
from nncf.common.graph.transformations.commands import TargetType
2219
from nncf.common.graph.transformations.layout import TransformationLayout
2320
from nncf.parameters import StripFormat
@@ -191,10 +188,8 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF
191188
:param strip_format: Describes the format in which the model is saved after strip.
192189
:return: The modified NNCF network.
193190
"""
194-
model_layout = model.nncf.transformation_layout()
195-
transformations = model_layout.transformations
196191
if strip_format == StripFormat.DQ:
197-
model = replace_with_decompressors(model, transformations)
192+
model = replace_with_decompressors(model)
198193
elif strip_format == StripFormat.NATIVE:
199194
model = replace_quantizer_to_torch_native_module(model)
200195
model = remove_disabled_quantizers(model)
@@ -204,7 +199,7 @@ def strip_quantized_model(model: NNCFNetwork, strip_format: StripFormat = StripF
204199
return model
205200

206201

207-
def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command]) -> NNCFNetwork:
202+
def replace_with_decompressors(model: NNCFNetwork) -> NNCFNetwork:
208203
"""
209204
Performs transformation from fake quantize format (FQ) to dequantization one (DQ).
210205
The former takes floating-point input, quantizes and dequantizes, and returns a floating-point value,
@@ -222,21 +217,21 @@ def replace_with_decompressors(model: NNCFNetwork, transformations: List[Command
222217
:return: The modified NNCF network.
223218
"""
224219
transformation_layout = TransformationLayout()
220+
transformations = model.nncf.transformation_layout().transformations
225221
model = model.nncf.get_clean_shallow_copy()
226222
graph = model.nncf.get_graph()
227-
228223
for command in transformations:
229224
quantizer = command.fn
230225

231-
msg = None
226+
msg = ""
232227
if not isinstance(quantizer, (SymmetricQuantizer, AsymmetricQuantizer)):
233-
msg = f"Unexpected compression module on strip: {quantizer.__class__}"
234-
elif quantizer._qspec.half_range or quantizer._qspec.narrow_range:
235-
msg = "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False."
236-
elif quantizer.num_bits not in [4, 8]:
237-
msg = f"Unsupported number of bits {quantizer.num_bits} for the quantizer {quantizer}"
238-
elif len(command.target_points) > 1:
239-
msg = "Command contains more than one target point!"
228+
msg = f"Unexpected compression module on strip: {quantizer.__class__}.\n"
229+
if quantizer._qspec.half_range or quantizer._qspec.narrow_range:
230+
msg += "Unexpected parameters of quantizers on strip: half_range and narrow_range should be False.\n"
231+
if quantizer.num_bits not in [4, 8]:
232+
msg += f"Unsupported number of bits {quantizer.num_bits} for the quantizer {quantizer}.\n"
233+
if len(command.target_points) > 1:
234+
msg += "Command contains more than one target point."
240235
if msg:
241236
raise nncf.ValidationError(msg)
242237

0 commit comments

Comments
 (0)