|
89 | 89 |
|
90 | 90 | from ..utils.constant import _TASK_ALIASES
|
91 | 91 | from ..utils.import_utils import is_transformers_version
|
92 |
| -from .configuration import OVConfig |
| 92 | +from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig |
93 | 93 | from .quantization import OVDataLoader
|
94 | 94 | from .training_args import OVTrainingArguments
|
95 | 95 | from .utils import (
|
@@ -225,37 +225,41 @@ def __init__(
|
225 | 225 | self.teacher.eval()
|
226 | 226 | self.compression_controller = None
|
227 | 227 |
|
228 |
| - if self.ov_config is not None and self.args.do_train: |
229 |
| - self._set_task() |
230 |
| - train_dataloader = self.get_train_dataloader() |
231 |
| - model_inputs = next(iter(train_dataloader)) |
232 |
| - for label_name in self.label_names: |
233 |
| - model_inputs.pop(label_name) |
234 |
| - force_batch_one = self._is_pruning_enabled() |
235 |
| - self.ov_config.add_input_info(model_inputs, force_batch_one) |
236 |
| - nncf_config = NNCFConfig.from_dict(self.ov_config.__dict__) |
237 |
| - nncf_config.register_extra_structs( |
238 |
| - [ |
239 |
| - QuantizationRangeInitArgs(OVDataLoader(train_dataloader)), |
240 |
| - BNAdaptationInitArgs(OVDataLoader(train_dataloader)), |
241 |
| - ] |
242 |
| - ) |
| 228 | + if self.ov_config is not None: |
| 229 | + if self.ov_config.compression is None: |
| 230 | + self.ov_config.compression = DEFAULT_QUANTIZATION_CONFIG |
| 231 | + |
| 232 | + if self.args.do_train: |
| 233 | + self._set_task() |
| 234 | + train_dataloader = self.get_train_dataloader() |
| 235 | + model_inputs = next(iter(train_dataloader)) |
| 236 | + for label_name in self.label_names: |
| 237 | + model_inputs.pop(label_name) |
| 238 | + force_batch_one = self._is_pruning_enabled() |
| 239 | + self.ov_config.add_input_info(model_inputs, force_batch_one) |
| 240 | + nncf_config = NNCFConfig.from_dict(self.ov_config.__dict__) |
| 241 | + nncf_config.register_extra_structs( |
| 242 | + [ |
| 243 | + QuantizationRangeInitArgs(OVDataLoader(train_dataloader)), |
| 244 | + BNAdaptationInitArgs(OVDataLoader(train_dataloader)), |
| 245 | + ] |
| 246 | + ) |
243 | 247 |
|
244 |
| - # Configure NNCF logging |
245 |
| - # Disable nncf logging to stdout except error |
246 |
| - # but to file nncf_output.log |
247 |
| - nncf_config["log_dir"] = args.output_dir |
248 |
| - nncf_log_file_handler = logging.logging.FileHandler(os.path.join(args.output_dir, NNCF_LOG_FILE_NAME)) |
249 |
| - nncf_log_file_handler.setFormatter(logging.logging.Formatter("%(levelname)s:%(name)s:%(message)s")) |
250 |
| - nncf_logger.addHandler(nncf_log_file_handler) |
251 |
| - set_log_level(logging.ERROR) |
252 |
| - nncf_logger.setLevel(logging.INFO) |
253 |
| - nncf_log_file_handler.setLevel(logging.INFO) |
254 |
| - |
255 |
| - self.compression_controller, self.model = create_compressed_model(self.model, nncf_config) |
256 |
| - self.model_wrapped = self.model |
257 |
| - # TODO : To deprecate once support transformers > 4.30.0 |
258 |
| - self.deepspeed = None |
| 248 | + # Configure NNCF logging |
| 249 | + # Disable nncf logging to stdout except error |
| 250 | + # but to file nncf_output.log |
| 251 | + nncf_config["log_dir"] = args.output_dir |
| 252 | + nncf_log_file_handler = logging.logging.FileHandler(os.path.join(args.output_dir, NNCF_LOG_FILE_NAME)) |
| 253 | + nncf_log_file_handler.setFormatter(logging.logging.Formatter("%(levelname)s:%(name)s:%(message)s")) |
| 254 | + nncf_logger.addHandler(nncf_log_file_handler) |
| 255 | + set_log_level(logging.ERROR) |
| 256 | + nncf_logger.setLevel(logging.INFO) |
| 257 | + nncf_log_file_handler.setLevel(logging.INFO) |
| 258 | + |
| 259 | + self.compression_controller, self.model = create_compressed_model(self.model, nncf_config) |
| 260 | + self.model_wrapped = self.model |
| 261 | + # TODO : To deprecate once support transformers > 4.30.0 |
| 262 | + self.deepspeed = None |
259 | 263 |
|
260 | 264 | def _set_signature_columns_if_needed(self):
|
261 | 265 | if self._signature_columns is None:
|
|
0 commit comments