19
19
from nncf .common .statistics import NNCFStatistics
20
20
from nncf .common .utils .api_marker import api
21
21
from nncf .common .utils .backend import copy_model
22
+ from nncf .parameters import StripFormat
22
23
23
24
TModel = TypeVar ("TModel" )
24
25
@@ -236,14 +237,17 @@ def statistics(self, quickly_collected_only: bool = False) -> NNCFStatistics:
236
237
need to keep track of statistics on each training batch/step/iteration.
237
238
"""
238
239
239
- def strip_model (self , model : TModel , do_copy : bool = False ) -> TModel :
240
+ def strip_model (
241
+ self , model : TModel , do_copy : bool = False , strip_format : StripFormat = StripFormat .NATIVE
242
+ ) -> TModel :
240
243
"""
241
244
Strips auxiliary layers that were used for the model compression, as it's
242
245
only needed for training. The method is used before exporting the model
243
246
in the target format.
244
247
245
248
:param model: The compressed model.
246
249
:param do_copy: Modify copy of the model, defaults to False.
250
+ :param strip format: Describes the format in which model is saved after strip.
247
251
:return: The stripped model.
248
252
"""
249
253
if do_copy :
@@ -256,16 +260,17 @@ def prepare_for_export(self) -> None:
256
260
"""
257
261
self ._model = self .strip_model (self ._model )
258
262
259
- def strip (self , do_copy : bool = True ) -> TModel : # type: ignore[type-var]
263
+ def strip (self , do_copy : bool = True , strip_format : StripFormat = StripFormat . NATIVE ) -> TModel : # type: ignore[type-var]
260
264
"""
261
- Returns the model object with as much custom NNCF additions as possible removed
262
- while still preserving the functioning of the model object as a compressed model.
265
+ Removes auxiliary layers and operations added during the compression process, resulting in a clean
266
+ model ready for deployment. The functionality of the model object is still preserved as a compressed model.
263
267
264
268
:param do_copy: If True (default), will return a copy of the currently associated model object. If False,
265
269
will return the currently associated model object "stripped" in-place.
270
+ :param strip format: Describes the format in which model is saved after strip.
266
271
:return: The stripped model.
267
272
"""
268
- return self .strip_model (self .model , do_copy ) # type: ignore
273
+ return self .strip_model (self .model , do_copy , strip_format ) # type: ignore
269
274
270
275
@abstractmethod
271
276
def export_model (
0 commit comments