@@ -8,7 +8,7 @@
 import time
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 from warnings import warn
 
 import torch
@@ -23,7 +23,8 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import padded_collate_packed
@@ -347,7 +348,7 @@ def setup(self, cfg: DictConfig) -> None:
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
         collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
-        self._sampler, self._dataloader = self._setup_data(
+        self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
@@ -686,11 +687,12 @@ def _setup_data(
         shuffle: bool,
         batch_size: int,
         collate_fn: str,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+        dataloader_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> StatefulDataLoader:
         """
-        All data related setup happens here. Currently this recipe only supports the
-        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
-        iterable datasets and streaming datasets are not supported.
+        All data related setup happens here. This recipe currently supports only
+        map-style datasets. If a state_dict is provided (meaning we are resuming a training run),
+        it is loaded into the dataloader.
         """
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
@@ -708,15 +710,13 @@ def _setup_data(
             raise RuntimeError("left_pad_sequence collator is only for inference.")
         collate_fn = _get_component_from_path(collate_fn)
 
-        sampler = DistributedSampler(
-            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
+        sampler = StatefulDistributedSampler(
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle
         )
-        dataloader = DataLoader(
+        dataloader = StatefulDataLoader(
             dataset=ds,
             batch_size=batch_size,
             sampler=sampler,
-            # dropping last avoids shape issues with compile + flex attention
-            drop_last=True,
             collate_fn=(
                 partial(
                     collate_fn,
@@ -726,11 +726,15 @@ def _setup_data(
                 if not packed
                 else padded_collate_packed
             ),
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
         )
-
-        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
-
-        return sampler, dataloader
+        if dataloader_state_dict is not None:
+            dataloader.load_state_dict(dataloader_state_dict)
+            # B/c we currently only save at epoch boundaries, if we cut the previous epoch short
+            # we need to force the dataloader to finish the last iteration before it's actually used
+            list(dataloader)
+        return dataloader
 
     def train(self) -> None:
         """
@@ -754,19 +758,9 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-            # Update the sampler to ensure data is correctly shuffled across epochs
-            # in case shuffle is True
-            self._sampler.set_epoch(curr_epoch)
-
             pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
+            self._dataloader.sampler.set_epoch(curr_epoch)
             for idx, batch in enumerate(self._dataloader):
-                if (
-                    self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
-                ):
-                    break
-
                 # Start tracking CUDA memory for active steps for just the first epoch
                 if (
                     self._is_rank_zero
@@ -908,6 +902,11 @@ def train(self) -> None:
                 # will include multiple forward / backward passes if gradient accumulation > 1
                 self._profiler.step()
 
+                if (
+                    (idx + 1) // self._gradient_accumulation_steps
+                ) == self.max_steps_per_epoch:
+                    break
+
             self.epochs_run += 1
             self._checkpoint_client.save_checkpoint(
                 model=self._model,
@@ -921,6 +920,7 @@ def train(self) -> None:
                     epochs_run=self.epochs_run,
                     total_epochs=self.total_epochs,
                     max_steps_per_epoch=self.max_steps_per_epoch,
+                    dataloader_state_dict=self._dataloader.state_dict(),
                 ),
                 epoch=curr_epoch,
             )
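
For reviewers who have not used torchdata's stateful classes before, here is a minimal, self-contained sketch (not recipe code) of the save/resume pattern this diff builds on: snapshot a partially consumed `StatefulDataLoader` with `state_dict()`, load that snapshot into a freshly constructed loader, and iteration picks up at the next batch. The toy dataset, the `build_loader` helper, and the batch sizes are illustrative assumptions; `num_replicas=1, rank=0` stands in for the recipe's `self.dp_size` / `self.dp_rank` so the example runs without initializing a process group.

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

# Toy map-style dataset: 16 scalar samples -> 4 batches of 4.
dataset = TensorDataset(torch.arange(16))


def build_loader() -> StatefulDataLoader:
    # num_replicas=1 / rank=0 avoid needing init_process_group in this sketch;
    # the recipe passes self.dp_size / self.dp_rank instead.
    sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    return StatefulDataLoader(dataset, batch_size=4, sampler=sampler, drop_last=True)


loader = build_loader()
loader.sampler.set_epoch(0)

batches = iter(loader)
_ = next(batches)            # consume one batch ...
state = loader.state_dict()  # ... then snapshot the mid-epoch progress

# "Resume": a brand-new loader continues from the snapshot instead of restarting
# the epoch, which is what _setup_data does with the checkpointed state_dict.
resumed = build_loader()
resumed.load_state_dict(state)
print(sum(1 for _ in resumed))  # expect 3: only the batches not yet consumed
```

Because the sampler's position is captured inside the dataloader's `state_dict()` (StatefulDataLoader snapshots any sampler that itself implements `state_dict`/`load_state_dict`), the recipe no longer needs to track `self._sampler` separately, while `dataloader.sampler.set_epoch(...)` still controls per-epoch shuffling.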