14
14
15
15
16
16
import copy
17
- import inspect
18
17
import uuid
19
18
from typing import Any , Callable , Dict , Generator , Iterator , List , Optional , Sized , Tuple , Union
20
19
21
20
from neural_compressor .common import Logger
22
- from neural_compressor .common .base_config import BaseConfig , ComposableConfig
21
+ from neural_compressor .common .base_config import BaseConfig
23
22
from neural_compressor .common .utils import TuningLogger
24
23
25
24
logger = Logger ().get_logger ()
@@ -227,19 +226,11 @@ def __iter__(self) -> Generator[BaseConfig, Any, None]:
227
226
228
227
229
228
class TuningConfig :
230
- """Base Class for Tuning Criterion.
231
-
232
- Args:
233
- config_set: quantization configs. Default value is empty.
234
- A single config or a list of configs. More details can
235
- be found in the `from_fwk_configs`of `ConfigSet` class.
236
- max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
237
- tolerable_loss: This float indicates how much metric loss we can accept. \
238
- The metric loss is relative, it can be both positive and negative. Default is 0.01.
229
+ """Config for auto tuning pipeline.
239
230
240
231
Examples:
241
232
# TODO: to refine it
242
- from neural_compressor import TuningConfig
233
+ from neural_compressor.torch.quantization import TuningConfig
243
234
tune_config = TuningConfig(
244
235
config_set=[config1, config2, ...],
245
236
max_trials=3,
@@ -264,13 +255,25 @@ class TuningConfig:
264
255
"""
265
256
266
257
def __init__ (
267
- self , config_set = None , max_trials = 100 , sampler : Sampler = default_sampler , tolerable_loss = 0.01
268
- ) -> None :
269
- """Init a TuneCriterion object."""
258
+ self ,
259
+ config_set : Union [BaseConfig , List [BaseConfig ]] = None ,
260
+ sampler : Sampler = default_sampler ,
261
+ tolerable_loss = 0.01 ,
262
+ max_trials = 100 ,
263
+ ):
264
+ """Initial a TuningConfig.
265
+
266
+ Args:
267
+ config_set: A single config or a list of configs. Defaults to None.
268
+ sampler: tuning sampler that decide the trials order. Defaults to default_sampler.
269
+ tolerable_loss: This float indicates how much metric loss we can accept.
270
+ The metric loss is relative, it can be both positive and negative. Default is 0.01.
271
+ max_trials: Max tuning times. Combine with `tolerable_loss` field to decide when to stop. Default is 100.
272
+ """
270
273
self .config_set = config_set
271
- self .max_trials = max_trials
272
274
self .sampler = sampler
273
275
self .tolerable_loss = tolerable_loss
276
+ self .max_trials = max_trials
274
277
275
278
276
279
class _TrialRecord :
0 commit comments