
Commit e4c6678

Update TF QAT docs. Deprecate TF create_compressed_model method (#3217)
### Changes

- Update TF QAT docs
- Deprecate the `create_compressed_model()` method for the TF backend.

### Reason for changes

Ref: 158980

### Related tickets

Ref: 158980

### Tests

N/A
1 parent b37408f commit e4c6678
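
In short, the deprecation steers TensorFlow users from the config-driven `create_compressed_model()` flow to the `nncf.quantize()` API. The following is a minimal sketch of the recommended flow, assuming a `tf.keras.Model` and an `nncf.Dataset`; `model`, `calibration_dataset`, and `train_dataset` are placeholders, not objects defined in this commit:

```python
import nncf

# Deprecated TF flow (now emits a deprecation warning, see model_creation.py below):
#   compression_ctrl, compressed_model = nncf.tensorflow.create_compressed_model(model, nncf_config)

# Recommended flow: quantize, optionally fine-tune with the regular Keras pipeline, then strip.
quantized_model = nncf.quantize(model, calibration_dataset)
quantized_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
quantized_model.fit(train_dataset, epochs=1)  # quantization-aware fine-tuning
stripped_model = nncf.strip(quantized_model)  # clean model ready for deployment
```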

File tree: 10 files changed (+234, −31 lines)


README.md (+3)

@@ -201,6 +201,9 @@ def transform_fn(data_item):
 calibration_dataset = nncf.Dataset(val_dataset, transform_fn)
 # Step 3: Run the quantization pipeline
 quantized_model = nncf.quantize(model, calibration_dataset)
+# Step 4: Remove auxiliary layers and operations added during the quantization process,
+# resulting in a clean, fully quantized model ready for deployment.
+stripped_model = nncf.strip(quantized_model)
 ```
 
 </details>
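
For context, the stripped model is a plain model that can be converted and saved like any other. A short sketch based on the example scripts updated in this commit; the output path is illustrative:

```python
import openvino as ov

# Convert the stripped model to OpenVINO IR and save it (the path is illustrative).
ov_quantized_model = ov.convert_model(stripped_model)
ov.save_model(ov_quantized_model, "mobilenet_v2_int8.xml")
```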

docs/usage/training_time_compression/quantization_aware_training/Usage.md (+93, −13)
@@ -1,7 +1,7 @@
-# Use NNCF for Quantization Aware Training in PyTorch
+# Use NNCF for Quantization Aware Training
 
-This is a step-by-step tutorial on how to integrate the NNCF package into the existing PyTorch project (please see the [TensorFlow quantization documentation](../other_algorithms/LegacyQuantization.md) for an integration tutorial for the existing TensorFlow project).
-The use case implies that the user already has a training pipeline that reproduces training of the model in the floating point precision and pretrained model.
+This is a step-by-step tutorial on how to integrate the NNCF package into an existing PyTorch or TensorFlow project.
+The use case implies that the user already has a training pipeline that reproduces training of the model in floating point precision and a pretrained model.
 The task is to prepare this model for accelerated inference by simulating the compression at train time.
 Please refer to this [document](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md) for details of the implementation.
 
@@ -11,11 +11,24 @@ Please refer to this [document](/docs/usage/training_time_compression/other_algo
 
 Quantize the model using the [Post Training Quantization](../../post_training_compression/post_training_quantization/Usage.md) method.
 
+<details open><summary><b>PyTorch</b></summary>
+
 ```python
 model = TorchModel() # instance of torch.nn.Module
 quantized_model = nncf.quantize(model, ...)
 ```
 
+</details>
+
+<details><summary><b>TensorFlow</b></summary>
+
+```python
+model = TensorFlowModel() # instance of tf.keras.Model
+quantized_model = nncf.quantize(model, ...)
+```
+
+</details>
+
 ### Step 2: Run the training pipeline
 
 At this point, the NNCF is fully integrated into your training pipeline.
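
Step 2 itself has no snippet in this diff. As an illustration, a minimal sketch of what the TensorFlow fine-tuning step can look like, assuming the quantized model is a compiled `tf.keras.Model`; `train_dataset` and the hyperparameters are placeholders, not values from this commit:

```python
import tensorflow as tf

# Minimal QAT fine-tuning sketch for the TensorFlow path.
quantized_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
quantized_model.fit(train_dataset, epochs=1)
```
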
@@ -27,27 +40,46 @@ Important points you should consider when training your networks with compressio
 
 ### Step 3: Export the compressed model
 
-After the compressed model has been fine-tuned to acceptable accuracy and compression stages, you can export it. There are two ways to export a model:
+After the compressed model has been fine-tuned to acceptable accuracy and compression stages, you can export it.
+
+<details open><summary><b>PyTorch</b></summary>
+
+Trace the model via inference in framework operations.
 
-1. Trace the model via inference in framework operations.
+```python
+# To OpenVINO format
+import openvino as ov
+ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
+```
+
+</details>
+
+<details><summary><b>TensorFlow</b></summary>
+
+```python
+# To OpenVINO format
+import openvino as ov
+
+# Removes auxiliary layers and operations added during the quantization process,
+# resulting in a clean, fully quantized model ready for deployment.
+stripped_model = nncf.strip(quantized_model)
+
+ov_quantized_model = ov.convert_model(stripped_model)
+```
 
-```python
-# To OpenVINO format
-import openvino as ov
-ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
-```
+</details>
 
 ## Saving and loading compressed models
 
+<details open><summary><b>PyTorch</b></summary>
+
 The complete information about compression is defined by a compressed model and a NNCF config.
 The model characterizes the weights and topology of the network. The NNCF config describes how to restore the additional modules introduced by NNCF.
 The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
 `nncf.torch.load_from_config` helper function to load additional modules from the given NNCF config.
 Saving the quantized model allows loading the quantized modules into the target model in a new Python process and
 requires only an example input for the target module, the corresponding NNCF config and the quantized model state dict.
 
-### Saving and loading compressed models in PyTorch
-
 ```python
 # save part
 quantized_model = nncf.quantize(model, calibration_dataset)
@@ -70,10 +102,53 @@ quantized_model.load_state_dict(state_dict)
 
 You can save the `compressed_model` object with `torch.save` as usual, via the `state_dict` and `load_state_dict` methods.
 
+</details>
+
+<details><summary><b>TensorFlow</b></summary>
+
+To save a model checkpoint, use the following API:
+
+```python
+from nncf.tensorflow import ConfigState
+from nncf.tensorflow import get_config
+from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback
+
+nncf_config = get_config(quantized_model)
+checkpoint = tf.train.Checkpoint(model=quantized_model,
+                                 nncf_config_state=ConfigState(nncf_config),
+                                 ... # the rest of the user-defined objects to save
+                                 )
+callbacks = []
+callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
+...
+quantized_model.fit(..., callbacks=callbacks)
+```
+
+To restore the model from checkpoint, use the following API:
+
+```python
+from nncf.tensorflow import ConfigState
+from nncf.tensorflow import load_from_config
+
+checkpoint = tf.train.Checkpoint(nncf_config_state=ConfigState())
+checkpoint.restore(path_to_checkpoint)
+
+quantized_model = load_from_config(model, checkpoint.nncf_config_state.config)
+
+checkpoint = tf.train.Checkpoint(model=quantized_model,
+                                 ... # the rest of the user-defined objects to load
+                                 )
+checkpoint.restore(path_to_checkpoint)
+```
+
+</details>
+
 ## Advanced usage
 
 ### Compression of custom modules
 
+<details open><summary><b>PyTorch</b></summary>
+
 With no target model code modifications, NNCF only supports native PyTorch modules with respect to trainable parameter (weight) compression, such as `torch.nn.Conv2d`.
 If your model contains a custom, non-PyTorch standard module with trainable weights that should be compressed, you can register it using the `@nncf.register_module` decorator:
 
@@ -91,4 +166,9 @@ If registered module should be ignored by specific algorithms use `ignored_algor
 
 In the example above, the NNCF-compressed models that contain instances of `MyModule` will have the corresponding modules extended with functionality that will allow NNCF to quantize the `weight` parameter of `MyModule` before it takes part in `MyModule`'s `forward` calculation.
 
-See a PyTorch [example](/examples/quantization_aware_training/torch/resnet18/README.md) for **Quantization** Compression scenario on Tiny ImageNet-200 dataset.
+</details>
+
+## Examples
+
+- See a PyTorch [example](/examples/quantization_aware_training/torch/resnet18/README.md) for the **Quantization** Compression scenario on the Tiny ImageNet-200 dataset.
+- See a TensorFlow [example](/examples/quantization_aware_training/tensorflow/mobilenet_v2/README.md) for the **Quantization** Compression scenario on the imagenette/320px-v2 dataset.
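
To complement the PyTorch saving/loading section above, whose snippet is truncated in this hunk, here is a hedged sketch of the full round trip it describes; the checkpoint layout and the exact `load_from_config` argument names are assumptions, not code from this commit:

```python
import torch
import nncf.torch

# Hypothetical save part: bundle the state dict with the NNCF config described above.
checkpoint = {
    "state_dict": quantized_model.state_dict(),
    "nncf_config": quantized_model.nncf.get_config(),
}
torch.save(checkpoint, "qat_checkpoint.pth")

# Hypothetical load part (e.g. in a new process): rebuild the NNCF modules, then load weights.
checkpoint = torch.load("qat_checkpoint.pth")
quantized_model = nncf.torch.load_from_config(model, checkpoint["nncf_config"], example_input)
quantized_model.load_state_dict(checkpoint["state_dict"])
```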

examples/post_training_quantization/tensorflow/mobilenet_v2/main.py (+2, −2)

@@ -151,8 +151,8 @@ def transform_fn(data_item):
 ###############################################################################
 # Benchmark performance, calculate compression rate and validate accuracy
 
-ov_model = ov.convert_model(tf_model, share_weights=False)
-ov_quantized_model = ov.convert_model(tf_quantized_model, share_weights=False)
+ov_model = ov.convert_model(tf_model)
+ov_quantized_model = ov.convert_model(tf_quantized_model)
 
 fp32_ir_path = ROOT / "mobilenet_v2_fp32.xml"
 ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False)

examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py (+2, −2)

@@ -167,8 +167,8 @@ def transform_fn(data_item):
 ###############################################################################
 # Benchmark performance, calculate compression rate and validate accuracy
 
-ov_model = ov.convert_model(tf_model, share_weights=False)
-ov_quantized_model = ov.convert_model(stripped_model, share_weights=False)
+ov_model = ov.convert_model(tf_model)
+ov_quantized_model = ov.convert_model(stripped_model)
 
 fp32_ir_path = ROOT / "mobilenet_v2_fp32.xml"
 ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False)

nncf/tensorflow/__init__.py (+3)

@@ -44,10 +44,13 @@
 )
 from nncf.tensorflow.helpers import create_compressed_model as create_compressed_model
 from nncf.tensorflow.helpers.callback_creation import create_compression_callbacks as create_compression_callbacks
+from nncf.tensorflow.helpers.model_creation import get_config
+from nncf.tensorflow.helpers.model_creation import load_from_config
 from nncf.tensorflow.initialization import register_default_init_args as register_default_init_args
 from nncf.tensorflow.pruning.filter_pruning import algorithm as filter_pruning_algorithm
 
 # Required for correct COMPRESSION_ALGORITHMS registry functioning
 from nncf.tensorflow.quantization import algorithm as quantization_algorithm
 from nncf.tensorflow.sparsity.magnitude import algorithm as magnitude_sparsity_algorithm
 from nncf.tensorflow.sparsity.rb import algorithm as rb_sparsity_algorithm
+from nncf.tensorflow.utils.state import ConfigState

nncf/tensorflow/helpers/model_creation.py (+71)

@@ -18,19 +18,25 @@
 from nncf import NNCFConfig
 from nncf.api.compression import CompressionAlgorithmController
 from nncf.common.compression import BaseCompressionAlgorithmController as BaseController
+from nncf.common.deprecation import warning_deprecated
 from nncf.common.utils.api_marker import api
 from nncf.config.extractors import extract_algorithm_names
 from nncf.config.telemetry_extractors import CompressionStartedFromConfig
 from nncf.config.utils import is_experimental_quantization
 from nncf.telemetry import tracked_function
 from nncf.telemetry.events import NNCF_TF_CATEGORY
+from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
 from nncf.tensorflow.accuracy_aware_training.keras_model_utils import accuracy_aware_fit
 from nncf.tensorflow.algorithm_selector import NoCompressionAlgorithmBuilder
 from nncf.tensorflow.algorithm_selector import get_compression_algorithm_builder
 from nncf.tensorflow.api.composite_compression import TFCompositeCompressionAlgorithmBuilder
 from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
+from nncf.tensorflow.graph.model_transformer import TFModelTransformer
+from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout
 from nncf.tensorflow.graph.utils import is_keras_layer_model
 from nncf.tensorflow.helpers.utils import get_built_model
+from nncf.tensorflow.quantization.algorithm import QuantizationBuilder
+from nncf.tensorflow.quantization.algorithm import TFQuantizationSetup
 
 
 def create_compression_algorithm_builder(config: NNCFConfig, should_init: bool) -> TFCompressionAlgorithmBuilder:
@@ -80,6 +86,27 @@ def create_compressed_model(
     :return: A tuple of the compression controller for the requested algorithm(s) and the model object with additional
         modifications necessary to enable algorithm-specific compression during fine-tuning.
     """
+
+    warning_deprecated(
+        "The 'nncf.tensorflow.create_compressed_model' function is deprecated and will be removed in a "
+        "future release.\n"
+        "To perform post training quantization (PTQ) or quantization aware training (QAT),"
+        " use the nncf.quantize() API:\n"
+        " - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#post-training-quantization\n"
+        " - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#training-time-quantization\n"
+        "Examples:\n"
+        " - https://github.com/openvinotoolkit/nncf/tree/develop/examples/post_training_quantization/tensorflow\n"
+        " - https://github.com/openvinotoolkit/nncf/tree/develop/examples/quantization_aware_training/tensorflow"
+    )
+    return create_compressed_model_impl(model, config, compression_state)
+
+
+def create_compressed_model_impl(
+    model: tf.keras.Model, config: NNCFConfig, compression_state: Optional[Dict[str, Any]] = None
+) -> Tuple[CompressionAlgorithmController, tf.keras.Model]:
+    """
+    Implementation of the create_compressed_model() method.
+    """
     if is_experimental_quantization(config):
         if is_keras_layer_model(model):
             msg = (
@@ -128,3 +155,47 @@ def get_input_signature(config: NNCFConfig):
         input_signature.append(tf.TensorSpec(shape=shape, dtype=tf.float32))
 
     return input_signature if len(input_signature) > 1 else input_signature[0]
+
+
+@tracked_function(
+    NNCF_TF_CATEGORY,
+    [
+        FunctionCallTelemetryExtractor("nncf.tensorflow.load_from_config"),
+    ],
+)
+def load_from_config(model: tf.keras.Model, config: Dict[str, Any]) -> tf.keras.Model:
+    """
+    Recovers additional modules from the given config.
+    Does not recover the weights of the additional modules, as they are located in the corresponding checkpoint file.
+
+    :param model: TensorFlow model.
+    :param config: Config.
+    :return: tf.keras.Model built from the given model with additional layers recovered from the given config.
+    """
+    quantizer_setup_state = config["quantization"]["quantizer_setup"]
+    quantizer_setup = TFQuantizationSetup.from_state(quantizer_setup_state)
+
+    transformation_layout = TFTransformationLayout()
+    # pylint: disable=protected-access
+    insertion_commands, _ = QuantizationBuilder.build_insertion_commands_for_quantizer_setup(quantizer_setup)
+    for command in insertion_commands:
+        transformation_layout.register(command)
+    model_transformer = TFModelTransformer(model)
+    return model_transformer.transform(transformation_layout)
+
+
+@tracked_function(
+    NNCF_TF_CATEGORY,
+    [
+        FunctionCallTelemetryExtractor("nncf.tensorflow.get_config"),
+    ],
+)
+def get_config(model: tf.keras.Model) -> Dict[str, Any]:
+    """
+    Extracts the config from the model.
+
+    :param model: Model.
+    :return: Config.
+    """
+    config = getattr(model, "_nncf_config")
+    return config
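
As a usage note, a minimal sketch of how these two new helpers are meant to be combined, assuming `quantized_model` comes from `nncf.quantize()` and `fp32_model` is the original Keras model; weight restoration is intentionally left to `tf.train.Checkpoint`, as the docstring above states:

```python
from nncf.tensorflow import get_config, load_from_config

# The config is attached to the model by nncf.quantize() (see quantize_model.py below).
nncf_config = get_config(quantized_model)

# Re-applies the recorded quantizer insertion commands to the original FP32 model.
restored_model = load_from_config(fp32_model, nncf_config)

# Quantizer and model weights are not restored here; load them from a
# tf.train.Checkpoint as shown in the updated QAT documentation above.
```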

nncf/tensorflow/quantization/algorithm.py (+15, −10)

@@ -348,14 +348,17 @@ def _get_half_range(
             return True
         return False
 
-    def _create_quantizer(self, name: str, qspec: TFQuantizerSpec) -> Quantizer:
+    @staticmethod
+    def _create_quantizer(name: str, qspec: TFQuantizerSpec) -> Quantizer:
         quantizer_cls = NNCF_QUANTIZATION_OPERATIONS.get(qspec.mode)
         return quantizer_cls(name, qspec)
 
-    def _build_insertion_commands_for_quantizer_setup(
-        self, quantizer_setup: TFQuantizationSetup
-    ) -> List[TFInsertionCommand]:
+    @staticmethod
+    def build_insertion_commands_for_quantizer_setup(
+        quantizer_setup: TFQuantizationSetup,
+    ) -> Tuple[List[TFInsertionCommand], List[str]]:
         insertion_commands = []
+        op_names = []
         quantization_points = quantizer_setup.get_quantization_points()
         non_unified_scales_quantization_point_ids = set(range(len(quantization_points)))
 
@@ -367,7 +370,7 @@ def _build_insertion_commands_for_quantizer_setup(
             quantizer_spec = qp.quantizer_spec
             op_name = qp.op_name + "/unified_scale_group"
             quantizer = FakeQuantize(quantizer_spec, name=op_name)
-            self._op_names.append(quantizer.op_name)
+            op_names.append(quantizer.op_name)
             target_points = []
             for us_qp_id in unified_scales_group:
                 non_unified_scales_quantization_point_ids.discard(us_qp_id)
@@ -389,24 +392,26 @@ def _build_insertion_commands_for_quantizer_setup(
             quantizer_spec = quantization_point.quantizer_spec
             target_point = quantization_point.target_point
             if quantization_point.is_weight_quantization():
-                quantizer = self._create_quantizer(op_name, quantizer_spec)
-                self._op_names.append(op_name)
+                quantizer = QuantizationBuilder._create_quantizer(op_name, quantizer_spec)
+                op_names.append(op_name)
             else:
                 quantizer = FakeQuantize(quantizer_spec, name=op_name)
-                self._op_names.append(quantizer.op_name)
+                op_names.append(quantizer.op_name)
             command = TFInsertionCommand(
                 target_point=target_point,
                 callable_object=quantizer,
                 priority=TransformationPriority.QUANTIZATION_PRIORITY,
             )
             insertion_commands.append(command)
-        return insertion_commands
+        return insertion_commands, op_names
 
     def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
         transformations = TFTransformationLayout()
         if self._quantizer_setup is None:
             self._quantizer_setup = self._get_quantizer_setup(model)
-        insertion_commands = self._build_insertion_commands_for_quantizer_setup(self._quantizer_setup)
+        insertion_commands, self._op_names = QuantizationBuilder.build_insertion_commands_for_quantizer_setup(
+            self._quantizer_setup
+        )
         for command in insertion_commands:
             transformations.register(command)
         return transformations

nncf/tensorflow/quantization/quantize_model.py (+8, −2)

@@ -28,7 +28,7 @@
 from nncf.quantization.advanced_parameters import apply_advanced_parameters_to_config
 from nncf.scopes import IgnoredScope
 from nncf.scopes import convert_ignored_scope_to_list
-from nncf.tensorflow.helpers.model_creation import create_compressed_model
+from nncf.tensorflow.helpers.model_creation import create_compressed_model_impl
 
 DEFAULT_RANGE_TYPE = "mean_min_max"
 
@@ -181,6 +181,12 @@ def quantize_impl(
         ]
     )
 
-    _, compressed_model = create_compressed_model(model=model, config=nncf_config)
+    compression_ctrl, compressed_model = create_compressed_model_impl(model=model, config=nncf_config)
+
+    # NOTE: We set the config here to properly save/load the quantized model during training into tf.train.Checkpoint.
+    # You can obtain that config via the nncf.tensorflow.get_config() method and save/load it to/from
+    # tf.train.Checkpoint using the nncf.tensorflow.ConfigState class.
+    config = compression_ctrl.get_compression_state()["builder_state"]
+    setattr(compressed_model, "_nncf_config", config)
 
     return compressed_model
