@@ -40,6 +40,7 @@
 )
 from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
+from sparseml.utils.pytorch import qat_active
 
 
 __all__ = [
@@ -137,7 +138,7 @@ def initialize_session(
         train_data = self.get_train_dataloader()
 
         self.accelerator.wait_for_everyone()
-        with summon_full_params_context(self.model):
+        with summon_full_params_context(self.model, offload_to_cpu=True):
             session_manager.initialize(
                 model=self.model,
                 teacher_model=self.teacher,  # TODO: what about for self/disable?
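For readers unfamiliar with the helper: `summon_full_params_context` gathers the full, unsharded parameters of an FSDP-wrapped model so they can be read in one place, and `offload_to_cpu=True` keeps the gathered copies in host memory instead of VRAM. A minimal sketch of such a context, assuming it wraps the stock `torch.distributed.fsdp` API (the real sparseml helper may differ):

```python
from contextlib import contextmanager, nullcontext

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@contextmanager
def summon_full_params_context(model, offload_to_cpu=False):
    # Hypothetical sketch: gather the full (unsharded) parameters of an
    # FSDP-wrapped model so weights can be inspected in one place.
    # offload_to_cpu=True stages the gathered copies in host memory,
    # which avoids GPU OOM when materializing a large model's weights.
    context = (
        FSDP.summon_full_params(model, offload_to_cpu=offload_to_cpu)
        if isinstance(model, FSDP)
        else nullcontext()
    )
    with context:
        yield
```

Passing `offload_to_cpu=True` matters here because gathering every shard onto each GPU can exhaust device memory for large models; host memory is a safer staging area for read-only inspection.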
@@ -370,9 +371,13 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
 
         self.accelerator.wait_for_everyone()
 
-        # Need to gather parameters across the GPUs before accessing layer weights
-        with summon_full_params_context(self.model):
-            self.log_model_sparsification()
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
+        self.accelerator.wait_for_everyone()
 
         return output
 
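The rewritten block changes behavior in three ways: sparsity is logged only on the main process (so multi-GPU runs do not print once per rank), it is skipped while QAT is active (fake-quant wrappers would make the zero counts misleading), and a second `wait_for_everyone()` barrier keeps the other ranks in step. As context for what the logging amounts to, a hypothetical sketch of a sparsity report, not sparseml's actual `log_model_sparsification`:

```python
import torch


def log_model_sparsification(model: torch.nn.Module) -> None:
    # Hypothetical sketch: report the fraction of zeroed weight elements.
    total, zeros = 0, 0
    for name, param in model.named_parameters():
        if param.dim() > 1:  # weight matrices/convs; skip biases and norms
            num_zeros = int((param == 0).sum())
            print(f"{name}: {num_zeros / param.numel():.2%} sparse")
            total += param.numel()
            zeros += num_zeros
    print(f"global sparsity: {zeros / max(total, 1):.2%}")
```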
@@ -434,6 +439,12 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             accelerator=self.accelerator,
         )
 
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
         self.accelerator.wait_for_everyone()
 
     def save_model(
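The same guarded block is added to `one_shot` so post-calibration sparsity is reported consistently. The `qat_active` helper imported above presumably inspects the model for quantization-aware-training state; a hypothetical sketch of such a check (sparseml's real implementation may differ):

```python
import torch


def qat_active(model: torch.nn.Module) -> bool:
    # Hypothetical sketch: treat the model as quantization-aware if any
    # submodule is a fake-quantization observer, since sparsity counted
    # over fake-quantized weights would be misleading to report.
    return any(
        isinstance(module, torch.ao.quantization.FakeQuantize)
        for module in model.modules()
    )
```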