Commit 3bf79ad

Author: Sara Adkins
Fix for OOM Errors during Ultrachat200k Finetuning (#2180) (#2181)
* testing fix
* get rid of repeated log
* revert yaml
1 parent d64d9fb commit 3bf79ad
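
The OOM fix centers on how full parameters are materialized when the model is inspected under FSDP. Below is a minimal sketch of that mechanism, assuming a model already wrapped in PyTorch's FullyShardedDataParallel; sparseml's summon_full_params_context is treated here as a thin wrapper over FSDP.summon_full_params, and the helper name is hypothetical:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def inspect_sparsity(model: FSDP) -> float:
    # offload_to_cpu=True gathers the unsharded parameters into host memory
    # instead of onto every GPU, which keeps the sparsity walk from
    # triggering CUDA OOM on large models; writeback=False because the
    # parameters are only read, never modified.
    with FSDP.summon_full_params(model, offload_to_cpu=True, writeback=False):
        zeros = sum(int((p == 0).sum()) for p in model.parameters())
        total = sum(p.numel() for p in model.parameters())
    return zeros / max(total, 1)

With offload_to_cpu left at its default of False, the gathered full parameters live on each GPU, which is presumably the allocation this change avoids.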

File tree

3 files changed: +18 -12 lines changed


src/sparseml/modifiers/pruning/constant/pytorch.py (+3)

@@ -71,6 +71,9 @@ def on_update(self, state: State, event: Event, **kwargs):
             def apply_masks(module):
                 mask_name = param_mask_name()
                 if hasattr(module, mask_name):
+                    mask = getattr(module, mask_name)
+                    if mask.device != module.weight.device:
+                        setattr(module, mask_name, mask.to(module.weight.device))
                     module.weight *= getattr(module, mask_name)
 
             state.model.model.apply(apply_masks)
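
For context on the hunk above: if a pruning mask buffer ends up on a different device than the weight it masks, the in-place multiply raises a device-mismatch error. A self-contained sketch of the same alignment pattern, using an illustrative nn.Linear and a hypothetical "weight_mask" buffer rather than sparseml's param_mask_name():

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
layer.register_buffer("weight_mask", (torch.rand(8, 8) > 0.5).float())

# Simulate the weight being moved while the mask stays behind.
if torch.cuda.is_available():
    layer.weight.data = layer.weight.data.cuda()

mask = layer.weight_mask
if mask.device != layer.weight.device:
    # Re-register the mask on the weight's device before applying it,
    # mirroring the added lines in apply_masks.
    layer.weight_mask = mask.to(layer.weight.device)

with torch.no_grad():
    layer.weight *= layer.weight_mask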

src/sparseml/transformers/finetune/runner.py (-8)

@@ -40,9 +40,7 @@
 )
 from sparseml.transformers.finetune.model_args import ModelArguments
 from sparseml.transformers.finetune.training_args import TrainingArguments
-from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, unwrap_and_export_model
-from sparseml.utils.pytorch import qat_active
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -287,12 +285,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
         session = session_manager.active_session()
         session.reset_stage()
 
-        # log model sparsity
-        with summon_full_params_context(self.trainer.model):
-            if self.trainer.accelerator.is_main_process:
-                if not qat_active(self.trainer.model):
-                    self.trainer.log_model_sparsification()
-
         # synchronize and clean up memory
         self.trainer.accelerator.wait_for_everyone()
         self.trainer.model = get_session_model()

src/sparseml/transformers/finetune/session_mixin.py (+15 -4)

@@ -40,6 +40,7 @@
 )
 from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
+from sparseml.utils.pytorch import qat_active
 
 
 __all__ = [
@@ -137,7 +138,7 @@ def initialize_session(
         train_data = self.get_train_dataloader()
 
         self.accelerator.wait_for_everyone()
-        with summon_full_params_context(self.model):
+        with summon_full_params_context(self.model, offload_to_cpu=True):
             session_manager.initialize(
                 model=self.model,
                 teacher_model=self.teacher,  # TODO: what about for self/disable?
@@ -370,9 +371,13 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
 
         self.accelerator.wait_for_everyone()
 
-        # Need to gather parameters across the GPUs before accessing layer weights
-        with summon_full_params_context(self.model):
-            self.log_model_sparsification()
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
+        self.accelerator.wait_for_everyone()
 
         return output
 
@@ -434,6 +439,12 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             accelerator=self.accelerator,
         )
 
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
         self.accelerator.wait_for_everyone()
 
     def save_model(
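
The train() and one_shot() changes share one pattern: sparsity is logged only on the main process, skipped when quantization-aware training is active, and followed by a barrier so other ranks do not run ahead. A hedged sketch of that control flow using Hugging Face accelerate directly; log_sparsity and quantization_active stand in for log_model_sparsification and sparseml.utils.pytorch.qat_active:

from accelerate import Accelerator

accelerator = Accelerator()

def log_sparsity(model, quantization_active: bool) -> None:
    # Only rank 0 walks the parameters, and the walk is skipped when QAT
    # is active, matching the guards added above.
    if accelerator.is_main_process and not quantization_active:
        zeros = sum(int((p == 0).sum()) for p in model.parameters())
        total = sum(p.numel() for p in model.parameters())
        print(f"model sparsity: {zeros / max(total, 1):.2%}")
    # Every rank reaches the same barrier whether or not it logged.
    accelerator.wait_for_everyone()

The explicit self.accelerator.wait_for_everyone() added after the logging context in train() serves the same purpose as the barrier here.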
