|
28 | 28 | from nncf.config.extractors import has_input_info_field
|
29 | 29 | from nncf.config.telemetry_extractors import CompressionStartedFromConfig
|
30 | 30 | from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
|
| 31 | +from nncf.experimental.torch2.function_hook.serialization import get_config as pt2_get_config |
| 32 | +from nncf.experimental.torch2.function_hook.serialization import load_from_config as pt2_load_from_config |
31 | 33 | from nncf.telemetry import tracked_function
|
32 | 34 | from nncf.telemetry.events import NNCF_PT_CATEGORY
|
33 | 35 | from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
|
@@ -397,18 +399,43 @@ def is_wrapped_model(model: Any) -> bool:
|
397 | 399 | FunctionCallTelemetryExtractor("nncf.torch.load_from_config"),
|
398 | 400 | ],
|
399 | 401 | )
|
400 |
| -def load_from_config(model: torch.nn.Module, config: Dict[str, Any], example_input: Any) -> NNCFNetwork: |
| 402 | +def load_from_config(model: Module, config: Dict[str, Any], example_input: Optional[Any] = None) -> Module: |
401 | 403 | """
|
402 |
| - Wraps given model to a NNCFNetwork and recovers additional modules from given NNCFNetwork config. |
| 404 | + Wraps given model and recovers additional modules from given config. |
403 | 405 | Does not recover additional modules weights as they are located in a corresponded state_dict.
|
404 | 406 |
|
405 | 407 | :param model: PyTorch model.
|
406 | 408 | :param config: NNCNetwork config.
|
407 | 409 | :param example_input: An example input that will be used for model tracing. A tuple is interpreted
|
408 | 410 | as an example input of a set of non keyword arguments, and a dict as an example input of a set
|
409 |
| - of keywords arguments. |
410 |
| - :return: NNCFNetwork builded from given model with additional modules recovered from given NNCFNetwork config. |
| 411 | + of keywords arguments. Required with enabled legacy tracing mode. |
| 412 | + :return: Wrapped model with additional modules recovered from given config. |
411 | 413 | """
|
| 414 | + if is_experimental_torch_tracing_enabled(): |
| 415 | + return pt2_load_from_config(model, config) |
| 416 | + |
| 417 | + if example_input is None: |
| 418 | + msg = "The 'example_input' parameter must be specified." |
| 419 | + raise nncf.InternalError(msg) |
| 420 | + |
412 | 421 | nncf_network = wrap_model(model, example_input, trace_parameters=config[NNCFNetwork.TRACE_PARAMETERS_KEY])
|
413 | 422 | transformation_layout = deserialize_transformations(config)
|
414 | 423 | return PTModelTransformer(nncf_network).transform(transformation_layout)
|
| 424 | + |
| 425 | + |
| 426 | +@tracked_function( |
| 427 | + NNCF_PT_CATEGORY, |
| 428 | + [ |
| 429 | + FunctionCallTelemetryExtractor("nncf.torch.get_config"), |
| 430 | + ], |
| 431 | +) |
| 432 | +def get_config(model: Module) -> Dict[str, Any]: |
| 433 | + """ |
| 434 | + Returns the configuration object of the compressed model. |
| 435 | +
|
| 436 | + :param model: The compressed model. |
| 437 | + :return: The configuration object of the compressed model. |
| 438 | + """ |
| 439 | + if is_experimental_torch_tracing_enabled(): |
| 440 | + return pt2_get_config(model) |
| 441 | + return model.nncf.get_config() |
0 commit comments