
Commit 4d9840c

Add StatefulDataLoader to select other recipes (pytorch#2431)
1 parent 0e8f840 commit 4d9840c

8 files changed: +125 −116 lines

recipes/dev/grpo_full_finetune_distributed.py

+23 −17
@@ -7,16 +7,16 @@
 import sys
 import time
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 from warnings import warn
 
 import torch
 from omegaconf import DictConfig, ListConfig
 from torch import nn
 from torch.distributed import destroy_process_group, init_process_group
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, generation, modules, rlhf, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.datasets import ConcatDataset
@@ -248,11 +248,16 @@ def setup(self, cfg: DictConfig) -> None:
         collate_name = cfg.get(
             "collate_fn", "torchtune.dev.grpo.data.padded_collate_rl"
         )
-        self._sampler, self._dataloader = self._setup_data(
+        self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
             collate_fn=collate_name,
+            dataloader_state_dict=(
+                checkpoint_dict[training.DATALOADER_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
         )
 
         # Finally update the recipe state which can only be correctly set after all of the
@@ -552,7 +557,8 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
         """
         All data related setup happens here. Currently this recipe only supports the
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
@@ -571,30 +577,32 @@ def _setup_data(
         # Instantiate collate_fn
         collate_fn = _get_component_from_path(collate_fn)
 
-        sampler = DistributedSampler(
+        sampler = StatefulDistributedSampler(
             ds,
             num_replicas=self.world_size,
             rank=self.rank,
             shuffle=shuffle,
             seed=self.seed,
         )
-        dataloader = DataLoader(
+        dataloader = StatefulDataLoader(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            # dropping last avoids shape issues with compile + flex attention
-            drop_last=True,
             collate_fn=(
                 partial(
                     collate_fn,
                     padding_idx=self._tokenizer.pad_id,
                 )
             ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
         )
-
-        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
-
-        return sampler, dataloader
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+            # B/c we currently only save at epoch boundaries, if we cut the previous epoch short
+            # we need to force the dataloader to finish the last iteration before it's actually used
+            list(dataloader)
+        return dataloader
 
     def save_checkpoint(
         self,
@@ -668,6 +676,7 @@ def save_checkpoint(
                     training.EPOCHS_KEY: self._epochs_run,
                     training.TOTAL_EPOCHS_KEY: self.total_epochs,
                     training.RNG_KEY: self._rng.get_state(),
+                    training.DATALOADER_KEY: self._dataloader.state_dict(),
                 }
             )
 
@@ -930,11 +939,8 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self._epochs_run, self.total_epochs):
-            # Update the sampler to ensure data is correctly shuffled across epochs
-            # in case shuffle is True
-            self._sampler.set_epoch(curr_epoch)
-
             pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
+            self._dataloader.sampler.set_epoch(curr_epoch)
             for idx, batch in enumerate(self._dataloader):
 
                 # Start tracking CUDA memory for active steps for just the first epoch
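The core pattern this recipe picks up is that both the sampler and the dataloader can now serialize their iteration state, so a resumed run continues from the batch it left off at instead of restarting the epoch. Below is a minimal single-process sketch of that pattern, not the recipe itself: the toy dataset, batch size, and the build_loader helper are made up for illustration, while StatefulDataLoader, StatefulDistributedSampler, state_dict(), and load_state_dict() are the same torchdata APIs the diff uses.

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

ds = list(range(16))  # stand-in for a map-style dataset


def build_loader(state_dict=None):
    # num_replicas=1 / rank=0 let the sampler run without init_process_group;
    # the recipe passes self.world_size / self.rank here instead.
    sampler = StatefulDistributedSampler(
        ds, num_replicas=1, rank=0, shuffle=True, seed=0
    )
    loader = StatefulDataLoader(
        dataset=ds, batch_size=4, sampler=sampler, drop_last=True
    )
    if state_dict is not None:
        loader.load_state_dict(state_dict)
    return loader


loader = build_loader()
it = iter(loader)
consumed = [next(it), next(it)]   # train on 2 of the 4 batches in this epoch
saved = loader.state_dict()       # what the recipe stores under training.DATALOADER_KEY

resumed = build_loader(state_dict=saved)
remaining = list(resumed)         # picks up at batch index 2 rather than at the start
print(len(consumed), len(remaining))  # 2 2

This is also why the epoch loop now calls self._dataloader.sampler.set_epoch(curr_epoch) instead of keeping a separate self._sampler handle: the sampler lives inside the stateful loader, and its per-epoch shuffle still needs to be reseeded each epoch.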

recipes/full_finetune_distributed.py

+27 −27
@@ -8,7 +8,7 @@
 import time
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 from warnings import warn
 
 import torch
@@ -23,7 +23,8 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import padded_collate_packed
@@ -347,7 +348,7 @@ def setup(self, cfg: DictConfig) -> None:
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
         collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
-        self._sampler, self._dataloader = self._setup_data(
+        self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
@@ -686,11 +687,12 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
         """
-        All data related setup happens here. Currently this recipe only supports the
-        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
-        iterable datasets and streaming datasets are not supported.
+        All data related setup happens here. This recipe currently supports only
+        map-style datasets. If a state_dict is provided (meaning we are resuming a training run),
+        it is loaded into the dataloader.
         """
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -708,15 +710,13 @@ def _setup_data(
             raise RuntimeError("left_pad_sequence collator is only for inference.")
         collate_fn = _get_component_from_path(collate_fn)
 
-        sampler = DistributedSampler(
-            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
+        sampler = StatefulDistributedSampler(
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
         )
-        dataloader = DataLoader(
+        dataloader = StatefulDataLoader(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            # dropping last avoids shape issues with compile + flex attention
-            drop_last=True,
             collate_fn=(
                 partial(
                     collate_fn,
@@ -726,11 +726,15 @@ def _setup_data(
                 if not packed
                 else padded_collate_packed
             ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
         )
-
-        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
-
-        return sampler, dataloader
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+            # B/c we currently only save at epoch boundaries, if we cut the previous epoch short
+            # we need to force the dataloader to finish the last iteration before it's actually used
+            list(dataloader)
+        return dataloader
 
     def train(self) -> None:
         """
@@ -754,19 +758,9 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-            # Update the sampler to ensure data is correctly shuffled across epochs
-            # in case shuffle is True
-            self._sampler.set_epoch(curr_epoch)
-
             pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
+            self._dataloader.sampler.set_epoch(curr_epoch)
             for idx, batch in enumerate(self._dataloader):
-                if (
-                    self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
-                ):
-                    break
-
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     self._is_rank_zero
@@ -908,6 +902,11 @@ def train(self) -> None:
                 # will include multiple forward / backward passes if gradient accumulation > 1
                 self._profiler.step()
 
+                if (
+                    (idx + 1) // self._gradient_accumulation_steps
+                ) == self.max_steps_per_epoch:
+                    break
+
             self.epochs_run += 1
             self._checkpoint_client.save_checkpoint(
                 model=self._model,
@@ -921,6 +920,7 @@ def train(self) -> None:
                     epochs_run=self.epochs_run,
                     total_epochs=self.total_epochs,
                     max_steps_per_epoch=self.max_steps_per_epoch,
+                    dataloader_state_dict=self._dataloader.state_dict(),
                 ),
                 epoch=curr_epoch,
             )
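Two smaller changes in this recipe are worth calling out: the dataloader's state_dict is now threaded into the checkpoint via dataloader_state_dict=self._dataloader.state_dict(), and the max_steps_per_epoch check moves from the top of the batch loop to the bottom, rewritten around idx + 1. A plausible reason for the move (not stated in the diff) is that breaking at the top of the next iteration has already pulled one extra batch from the now-stateful dataloader, so the saved position would sit one batch ahead of what was actually trained on. The toy loop below uses made-up numbers and a plain range() in place of the real dataloader; it only shows that the bottom-of-loop (idx + 1) form caps the epoch at the same number of batches as the old top-of-loop check while never fetching that extra batch.

GRAD_ACCUM_STEPS = 2       # illustrative values, not taken from any recipe config
MAX_STEPS_PER_EPOCH = 3


def run_epoch(break_at_top: bool) -> tuple[int, int]:
    """Return (batches_fetched, batches_trained_on) for a toy 100-batch epoch."""
    fetched = trained = 0
    for idx in range(100):                    # each iteration "fetches" one batch
        fetched += 1
        if break_at_top and (idx // GRAD_ACCUM_STEPS) == MAX_STEPS_PER_EPOCH:
            break                             # old placement: batch fetched but never trained on
        trained += 1                          # forward / backward would happen here
        if not break_at_top and ((idx + 1) // GRAD_ACCUM_STEPS) == MAX_STEPS_PER_EPOCH:
            break                             # new placement: stop only after the work is done
    return fetched, trained


print(run_epoch(break_at_top=True))   # (7, 6): one batch pulled from the loader but skipped
print(run_epoch(break_at_top=False))  # (6, 6): fetched count matches trained count

The explicit "is not None" guard from the old version also becomes unnecessary: when max_steps_per_epoch is None the equality never holds and the loop simply runs the full epoch.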

recipes/full_finetune_single_device.py

+4 −4
@@ -594,6 +594,9 @@ def _setup_data(
         )
         if dataloader_state_dict is not None:
             dataloader.load_state_dict(dataloader_state_dict)
+            # B/c we currently only save at epoch boundaries, if we cut the previous epoch short
+            # we need to force the dataloader to finish the last iteration before it's actually used
+            list(dataloader)
         return dataloader
 
     def save_checkpoint(self, epoch: int) -> None:
@@ -604,16 +607,13 @@ def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
         # if training is in-progress, checkpoint the optimizer state as well
         if epoch + 1 < self.total_epochs:
-            dataloader_sd = self._dataloader.state_dict()
-            # Hardcode _iterator_finished to True to avoid issues with resuming from a checkpoint
-            dataloader_sd["_iterator_finished"] = True
             ckpt_dict.update(
                 {
                     training.SEED_KEY: self.seed,
                     training.EPOCHS_KEY: self.epochs_run,
                     training.TOTAL_EPOCHS_KEY: self.total_epochs,
                     training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                    training.DATALOADER_KEY: dataloader_sd,
+                    training.DATALOADER_KEY: self._dataloader.state_dict(),
                 }
            )
         if not self._optimizer_in_bwd:
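The single-device recipe previously patched the saved state by hardcoding _iterator_finished = True before writing the checkpoint; this commit drops that workaround and instead drains the restored loader with list(dataloader) in _setup_data, matching the distributed recipes. The sketch below is a toy, single-process illustration (the dataset, batch size, and variable names are made up; only StatefulDataLoader and its state_dict/load_state_dict come from torchdata): a checkpoint written after an epoch that was cut short carries a mid-epoch position, and iterating the restored loader once flushes the leftover batches so the next pass over it starts a clean epoch.

from torchdata.stateful_dataloader import StatefulDataLoader

ds = list(range(8))                        # toy map-style dataset -> 4 batches of 2
loader = StatefulDataLoader(ds, batch_size=2)

it = iter(loader)
next(it)                                   # epoch cut short after 1 of 4 batches
saved = loader.state_dict()                # this mid-epoch state lands in the checkpoint

restored = StatefulDataLoader(ds, batch_size=2)
restored.load_state_dict(saved)
leftover = list(restored)                  # drains the 3 batches left from the truncated epoch
fresh_epoch = list(restored)               # the following iteration is a full, clean epoch again
print(len(leftover), len(fresh_epoch))     # 3 4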
