Skip to content

Commit 9f4aa76

Browse files
[PT] change api for get_config and load_from_config (#3359)
### Changes Add `get_config` to `nncf/torch/__init__.py` Use function `get_config` and `load_from_config` for new and old tracing as `from nncf.torch import get_config, load_from_config` ### Tests https://github.com/openvinotoolkit/nncf/actions/runs/13974357105
1 parent faa6bed commit 9f4aa76

File tree

6 files changed

+43
-12
lines changed

6 files changed

+43
-12
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ Here is an example of Accuracy Aware Quantization pipeline where model weights a
245245

246246
```python
247247
import nncf
248+
import nncf.torch
248249
import torch
249250
from torchvision import datasets, models
250251

@@ -271,7 +272,7 @@ quantized_model = nncf.quantize(model, calibration_dataset)
271272
# Save quantization modules and the quantized model parameters
272273
checkpoint = {
273274
'state_dict': model.state_dict(),
274-
'nncf_config': model.nncf.get_config(),
275+
'nncf_config': nncf.torch.get_config(model),
275276
... # the rest of the user-defined objects to save
276277
}
277278
torch.save(checkpoint, path_to_checkpoint)

docs/usage/training_time_compression/quantization_aware_training/Usage.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,19 @@ ov_quantized_model = ov.convert_model(stripped_model)
7575

7676
The complete information about compression is defined by a compressed model and a NNCF config.
7777
The model characterizes the weights and topology of the network. The NNCF config - how to restore additional modules introduced by NNCF.
78-
The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
78+
The NNCF config can be obtained by `nncf.torch.get_config` on saving and passed to the
7979
`nncf.torch.load_from_config` helper function to load additional modules from the given NNCF config.
8080
The quantized model saving allows to load quantized modules to the target model in a new python process and
8181
requires only example input for the target module, corresponding NNCF config and the quantized model state dict.
8282

8383
```python
84+
import nncf.torch
85+
8486
# save part
8587
quantized_model = nncf.quantize(model, calibration_dataset)
8688
checkpoint = {
87-
'state_dict':quantized_model.state_dict(),
88-
'nncf_config': quantized_model.nncf.get_config(),
89+
'state_dict': quantized_model.state_dict(),
90+
'nncf_config': nncf.torch.get_config(quantized_model),
8991
...
9092
}
9193
torch.save(checkpoint, path)

examples/quantization_aware_training/torch/resnet18/main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,11 @@ def transform_fn(data_item):
278278
print(f"Train epoch: {epoch}")
279279
train_epoch(train_loader, quantized_model, criterion, optimizer, device=device)
280280
acc1_int8 = validate(val_loader, quantized_model, device)
281-
print(f"Accyracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
281+
print(f"Accuracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
282282
# Save the compression checkpoint for model with the best accuracy metric.
283283
if acc1_int8 > acc1_int8_best:
284284
state_dict = quantized_model.state_dict()
285-
compression_config = quantized_model.nncf.get_config()
285+
compression_config = nncf.torch.get_config(quantized_model)
286286
torch.save(
287287
{
288288
"model_state_dict": state_dict,

nncf/torch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from nncf.torch.model_creation import is_wrapped_model
5151
from nncf.torch.model_creation import wrap_model
5252
from nncf.torch.model_creation import load_from_config
53+
from nncf.torch.model_creation import get_config
5354
from nncf.torch.checkpoint_loading import load_state
5455
from nncf.torch.initialization import register_default_init_args
5556
from nncf.torch.layers import register_module

nncf/torch/model_creation.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from nncf.config.extractors import has_input_info_field
2929
from nncf.config.telemetry_extractors import CompressionStartedFromConfig
3030
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
31+
from nncf.experimental.torch2.function_hook.serialization import get_config as pt2_get_config
32+
from nncf.experimental.torch2.function_hook.serialization import load_from_config as pt2_load_from_config
3133
from nncf.telemetry import tracked_function
3234
from nncf.telemetry.events import NNCF_PT_CATEGORY
3335
from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
@@ -397,18 +399,43 @@ def is_wrapped_model(model: Any) -> bool:
397399
FunctionCallTelemetryExtractor("nncf.torch.load_from_config"),
398400
],
399401
)
400-
def load_from_config(model: torch.nn.Module, config: Dict[str, Any], example_input: Any) -> NNCFNetwork:
402+
def load_from_config(model: Module, config: Dict[str, Any], example_input: Optional[Any] = None) -> Module:
401403
"""
402-
Wraps given model to a NNCFNetwork and recovers additional modules from given NNCFNetwork config.
404+
Wraps given model and recovers additional modules from given config.
403405
Does not recover additional modules weights as they are located in a corresponded state_dict.
404406
405407
:param model: PyTorch model.
406408
:param config: NNCNetwork config.
407409
:param example_input: An example input that will be used for model tracing. A tuple is interpreted
408410
as an example input of a set of non keyword arguments, and a dict as an example input of a set
409-
of keywords arguments.
410-
:return: NNCFNetwork builded from given model with additional modules recovered from given NNCFNetwork config.
411+
of keywords arguments. Required with enabled legacy tracing mode.
412+
:return: Wrapped model with additional modules recovered from given config.
411413
"""
414+
if is_experimental_torch_tracing_enabled():
415+
return pt2_load_from_config(model, config)
416+
417+
if example_input is None:
418+
msg = "The 'example_input' parameter must be specified."
419+
raise nncf.InternalError(msg)
420+
412421
nncf_network = wrap_model(model, example_input, trace_parameters=config[NNCFNetwork.TRACE_PARAMETERS_KEY])
413422
transformation_layout = deserialize_transformations(config)
414423
return PTModelTransformer(nncf_network).transform(transformation_layout)
424+
425+
426+
@tracked_function(
427+
NNCF_PT_CATEGORY,
428+
[
429+
FunctionCallTelemetryExtractor("nncf.torch.get_config"),
430+
],
431+
)
432+
def get_config(model: Module) -> Dict[str, Any]:
433+
"""
434+
Returns the configuration object of the compressed model.
435+
436+
:param model: The compressed model.
437+
:return: The configuration object of the compressed model.
438+
"""
439+
if is_experimental_torch_tracing_enabled():
440+
return pt2_get_config(model)
441+
return model.nncf.get_config()

tests/torch2/function_hook/test_serialization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from nncf.experimental.torch2.function_hook import register_post_function_hook
2020
from nncf.experimental.torch2.function_hook import register_pre_function_hook
2121
from nncf.experimental.torch2.function_hook import wrap_model
22-
from nncf.experimental.torch2.function_hook.serialization import get_config
23-
from nncf.experimental.torch2.function_hook.serialization import load_from_config
22+
from nncf.torch import get_config
23+
from nncf.torch import load_from_config
2424
from tests.torch2.function_hook.helpers import HookWithState
2525
from tests.torch2.function_hook.helpers import SimpleModel
2626

0 commit comments

Comments
 (0)