diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py
index ee72b7f78..f8e94ffad 100644
--- a/test/nodes/test_multi_node_weighted_sampler.py
+++ b/test/nodes/test_multi_node_weighted_sampler.py
@@ -45,9 +45,7 @@ def _setup_multi_node_weighted_sampler(self, num_samples, num_datasets, weights_
         return Prefetcher(node, prefetch_factor=3)
 
     def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
-        """
-        Validation should fail if the keys of source_nodes and weights are not the same
-        """
+        """Test that MultiNodeWeightedSampler validation fails when the keys of source_nodes and weights are not the same."""
         with self.assertRaisesRegex(
             ValueError,
             "keys of source_nodes and weights must be the same",
@@ -60,9 +58,7 @@ def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
     def test_multi_node_weighted_batch_sampler_invalid_weights_tensor_shape(
         self,
     ) -> None:
-        """
-        Validation should fail if the shape of the weights tensor is invalid
-        """
+        """Test that MultiNodeWeightedSampler validation fails when the shape of the weights tensor is invalid."""
         with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
             MultiNodeWeightedSampler(
                 self.datasets,
@@ -72,9 +68,7 @@ def test_multi_node_weighted_batch_sampler_invalid_weights_tensor_shape(
     def test_multi_node_weighted_batch_sampler_negative_weights(
        self,
     ) -> None:
-        """
-        Validation should fail if the value of the weights tensor is invalid
-        """
+        """Test that MultiNodeWeightedSampler validation fails when the weights tensor contains negative values."""
         with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
             MultiNodeWeightedSampler(
                 self.datasets,
@@ -84,9 +78,7 @@ def test_multi_node_weighted_batch_sampler_negative_weights(
     def test_multi_node_weighted_batch_sampler_zero_weights(
         self,
     ) -> None:
-        """
-        Validation should fail if the value of the weights tensor is invalid
-        """
+        """Test that MultiNodeWeightedSampler validation fails when the weights tensor contains zero weights."""
         with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
             MultiNodeWeightedSampler(
                 self.datasets,
@@ -94,6 +86,7 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
             )
 
     def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
+        """Test MultiNodeWeightedSampler with the FIRST_DATASET_EXHAUSTED stop criterion."""
         mixer = self._setup_multi_node_weighted_sampler(
             self._num_samples,
             self._num_datasets,
@@ -115,6 +108,7 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
         mixer.reset()
 
     def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
+        """Test MultiNodeWeightedSampler with the ALL_DATASETS_EXHAUSTED stop criterion."""
         mixer = self._setup_multi_node_weighted_sampler(
             self._num_samples,
             self._num_datasets,
@@ -139,6 +133,7 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
         mixer.reset()
 
     def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
+        """Test MultiNodeWeightedSampler with the CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED stop criterion."""
         mixer = self._setup_multi_node_weighted_sampler(
             self._num_samples,
             self._num_datasets,
@@ -154,49 +149,15 @@ def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
         self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"])
         mixer.reset()
 
-    @parameterized.expand(
-        itertools.product(
-            [1, 4, 7],
-            [
-                StopCriteria.ALL_DATASETS_EXHAUSTED,
-                StopCriteria.FIRST_DATASET_EXHAUSTED,
-                StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
-            ],
-        )
-    )
-    def test_save_load_state_mixer(self, midpoint: int, stop_criteria: str):
-        mixer = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
-        run_test_save_load_state(self, mixer, midpoint)
-
-    @parameterized.expand(
-        itertools.product(
-            [100, 500, 1200],
-            [
-                StopCriteria.ALL_DATASETS_EXHAUSTED,
-                StopCriteria.FIRST_DATASET_EXHAUSTED,
-                StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
-            ],
-        )
-    )
-    def test_multi_node_weighted_large_sample_size(self, midpoint, stop_criteria) -> None:
-        num_samples = 1500
-        num_datasets = 5
-
-        mixer = self._setup_multi_node_weighted_sampler(
-            num_samples,
-            num_datasets,
-            self._weights_fn,
-            stop_criteria,
-        )
-        run_test_save_load_state(self, mixer, midpoint)
-
     @parameterized.expand([(1, 8), (8, 32)])
     def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size):
+        """Test MultiNodeWeightedSampler with different ranks and world sizes."""
         mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size)
         self.assertEqual(mixer.rank, rank)
         self.assertEqual(mixer.world_size, world_size)
 
     def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
+        """Test that MultiNodeWeightedSampler produces different results across ranks."""
         world_size = 8
         global_results = []
         for rank in range(world_size):
@@ -211,14 +172,16 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
         self.assertEqual(unique_results, global_results)
 
     def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
-        # Check the mixer node
+        """Test that MultiNodeWeightedSampler produces different results in each epoch."""
+
+        # Check for the mixer node only
         mixer = MultiNodeWeightedSampler(
             self.datasets,
             self.weights,
         )
 
         overall_results = []
-        for i in range(self._num_epochs):
+        for _ in range(self._num_epochs):
             results = list(mixer)
             overall_results.append(results)
             mixer.reset()
@@ -230,7 +193,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
 
         self.assertEqual(unique_results, overall_results)
 
-        # Check mixer along with Prefetcher node
+        # Check for mixer along with Prefetcher node
         node = self._setup_multi_node_weighted_sampler(
             self._num_samples,
             self._num_datasets,
@@ -239,7 +202,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
         )
 
         overall_results = []
-        for i in range(self._num_epochs):
+        for _ in range(self._num_epochs):
             results = list(node)
             overall_results.append(results)
             node.reset()
@@ -262,6 +225,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
         )
     )
     def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
+        """Test MultiNodeWeightedSampler with saving and loading of state across multiple epochs."""
         node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
         run_test_save_load_state(self, node, midpoint)
 
@@ -283,3 +247,26 @@ def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoi
             stop_criteria=stop_criteria,
         )
         run_test_save_load_state(self, node, midpoint)
+
+    @parameterized.expand(
+        itertools.product(
+            [100, 500, 1200],
+            [
+                StopCriteria.ALL_DATASETS_EXHAUSTED,
+                StopCriteria.FIRST_DATASET_EXHAUSTED,
+                StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+            ],
+        )
+    )
+    def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, stop_criteria) -> None:
+        """Test MultiNodeWeightedSampler (larger sample sizes) with saving and loading of state across multiple epochs."""
"""Test MultiNodeWeightedSampler (larger sample sizes) with saving and loading of state across multiple epochs""" + num_samples = 1500 + num_datasets = 5 + + mixer = self._setup_multi_node_weighted_sampler( + num_samples, + num_datasets, + self._weights_fn, + stop_criteria, + ) + run_test_save_load_state(self, mixer, midpoint) diff --git a/test/nodes/utils.py b/test/nodes/utils.py index 857db9396..c19df163f 100644 --- a/test/nodes/utils.py +++ b/test/nodes/utils.py @@ -134,7 +134,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): self.i = 0 self.num_resets += 1 - def next(self) -> Iterator[Dict[str, int]]: + def next(self) -> Dict[str, int]: if self.i == self.n: raise StopIteration() ret = {"i": self.i, "resets": self.num_resets} diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py index 27f1f8922..b16d8d89d 100644 --- a/torchdata/nodes/batch.py +++ b/torchdata/nodes/batch.py @@ -10,6 +10,20 @@ class Batcher(BaseNode[List[T]]): + """Batcher node batches the data from the source node into batches of size batch_size. + If the source node is exhausted, it will return the batch or raise StopIteration. + If drop_last is True, the last batch will be dropped if it is smaller than batch_size. + If drop_last is False, the last batch will be returned even if it is smaller than batch_size. + + Parameters: + source (BaseNode[T]): The source node to batch the data from. + batch_size (int): The size of the batch. + drop_last (bool): Whether to drop the last batch if it is smaller than batch_size. Default is True. + + Attributes: + SOURCE_KEY (str): The key for the source node in the state dict. + """ + SOURCE_KEY = "source" def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True): diff --git a/torchdata/nodes/loader.py b/torchdata/nodes/loader.py index 93a033a5c..fe0dacb56 100644 --- a/torchdata/nodes/loader.py +++ b/torchdata/nodes/loader.py @@ -4,6 +4,16 @@ class Loader(Generic[T]): + """Wraps the root node (iterator) and provides a stateful iterable interface. + + The state of the last-returned iterator is returned by the state_dict() method, and can be + loaded using the load_state_dict() method. + + Parameters: + root (BaseNode[T]): The root node of the data pipeline. + restart_on_stop_iteration (bool): Whether to restart the iterator when it reaches the end. Default is True + """ + def __init__(self, root: BaseNode[T], restart_on_stop_iteration: bool = True): super().__init__() self.root = root @@ -49,6 +59,16 @@ def state_dict(self) -> Dict[str, Any]: class LoaderIterator(BaseNode[T]): + """An iterator class that wraps a root node and works with the Loader class. + + The LoaderIterator object saves state of the underlying root node, and calls reset on the root node when + the iterator is exhausted or on a reset call. We look one step ahead to determine if the iterator is exhausted. + The state of the iterator is saved in the state_dict() method, and can be loaded on reset calls. + + Parameters: + loader (Loader[T]): The loader object that contains the root node. + """ + NUM_YIELDED_KEY = "num_yielded" ROOT_KEY = "root" diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index b999cb138..1a1a6a900 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -278,6 +278,16 @@ class ParallelMapper(BaseNode[T]): If in_order is true, the iterator will return items in the order from which they arrive from source's iterator, potentially blocking even if other items are available. 
+
+    Parameters:
+        source (BaseNode[X]): The source node to map over.
+        map_fn (Callable[[X], T]): The function to apply to each item from the source node.
+        num_workers (int): The number of workers to use for parallel processing.
+        in_order (bool): Whether to return items in the order in which they arrive from the source. Default is True.
+        method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
+        multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
+        max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
+        snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
     """
 
     def __init__(
diff --git a/torchdata/nodes/pin_memory.py b/torchdata/nodes/pin_memory.py
index cc9e45caf..5acf2021e 100644
--- a/torchdata/nodes/pin_memory.py
+++ b/torchdata/nodes/pin_memory.py
@@ -31,11 +31,12 @@ def _pin_memory_loop(
     device_id: Union[int, str],
     device: Optional[str],
 ):
-    # this is fork of from torch.utils.data._utils.pin_memory import _pin_memory_loop
-    # to remove the index tuples
+    """This is a fork of torch.utils.data._utils.pin_memory._pin_memory_loop,
+    modified to remove the index tuples.
 
-    # This setting is thread local, and prevents the copy in pin_memory from
-    # consuming all CPU cores.
+    This setting is thread local, and prevents the copy in pin_memory from
+    consuming all CPU cores.
+    """
 
     idx = MonotonicIndex()
@@ -94,6 +95,17 @@ def _put(
 
 
 class PinMemory(BaseNode[T]):
+    """Pins the data of the underlying node to a device. This is backed by torch.utils.data._utils.pin_memory._pin_memory_loop.
+
+    Parameters:
+        source (BaseNode[T]): The source node to pin the data from.
+        pin_memory_device (str): The device to pin the data to. Default is "".
+        snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is
+            1, which means that the state of the source node will be snapshotted after every item. If set
+            to a higher value, the state of the source node will be snapshotted after every snapshot_frequency
+            items.
+    """
+
     def __init__(
         self,
         source: BaseNode[T],
diff --git a/torchdata/nodes/prefetch.py b/torchdata/nodes/prefetch.py
index ed80f54fb..2f22d33fa 100644
--- a/torchdata/nodes/prefetch.py
+++ b/torchdata/nodes/prefetch.py
@@ -14,13 +14,23 @@
 
 
 class Prefetcher(BaseNode[T]):
+    """Prefetches data from the source node and stores it in a queue.
+
+    Parameters:
+        source (BaseNode[T]): The source node to prefetch data from.
+        prefetch_factor (int): The number of items to prefetch ahead of time.
+        snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is
+            1, which means that the state of the source node will be snapshotted after every item. If set
+            to a higher value, the state of the source node will be snapshotted after every snapshot_frequency
+            items.
+    """
+
     def __init__(self, source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1):
         super().__init__()
         self.source = source
         self.prefetch_factor = prefetch_factor
         self.snapshot_frequency = snapshot_frequency
         self._it: Optional[_SingleThreadedMapper[T]] = None
-        self._iter_for_state_dict: bool = False
 
     def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         super().reset(initial_state)
diff --git a/torchdata/nodes/samplers/multi_node_weighted_sampler.py b/torchdata/nodes/samplers/multi_node_weighted_sampler.py
index 0dcac5b22..c89ae4942 100644
--- a/torchdata/nodes/samplers/multi_node_weighted_sampler.py
+++ b/torchdata/nodes/samplers/multi_node_weighted_sampler.py
@@ -201,6 +201,23 @@ def get_state(self) -> Dict[str, Any]:
 
 
 class _WeightedSampler:
+    """A weighted sampler that draws source node keys according to their weights.
+
+    The class implements the state using the following keys:
+    - g_state: The state of the random number generator.
+    - g_rank_state: The state of the random number generator for the rank.
+    - offset: The offset of the batch of indices.
+
+    Parameters:
+        weights (Dict[str, float]): A dictionary of weights for each source node.
+        seed (int): The seed for the random number generator.
+        rank (int): The rank of the current process.
+        world_size (int): The world size of the distributed environment.
+        random_tensor_batch_size (int): Generating random numbers in batches is faster than individually.
+            This setting controls the batch size, but is invisible to users and shouldn't need to be tuned. Default is 1000.
+        initial_state (Optional[Dict[str, Any]]): The initial state of the sampler. Default is None.
+    """
+
     def __init__(
         self,
         weights: Dict[str, float],
@@ -208,7 +225,7 @@ def __init__(
         rank: int,
         world_size: int,
         epoch: int,
-        randon_tensor_batch_size: int = 1000,
+        random_tensor_batch_size: int = 1000,
         initial_state: Optional[Dict[str, Any]] = None,
     ):
         _names, _weights = [], []
@@ -219,7 +236,7 @@ def __init__(
 
         self.names = _names
         self.weights = torch.tensor(_weights, dtype=torch.float64)
-        self.randon_tensor_batch_size = randon_tensor_batch_size
+        self.random_tensor_batch_size = random_tensor_batch_size
 
         self._g = torch.Generator()
         self._g_rank = torch.Generator()
@@ -241,7 +258,7 @@ def _get_batch_of_indices(self) -> list[int]:
         self._g_snapshot = self._g.get_state()
         return torch.multinomial(
             self.weights,
-            num_samples=self.randon_tensor_batch_size,
+            num_samples=self.random_tensor_batch_size,
             replacement=True,
             generator=self._g,
         ).tolist()
diff --git a/torchdata/nodes/snapshot_store.py b/torchdata/nodes/snapshot_store.py
index c5c2189ce..52f751089 100644
--- a/torchdata/nodes/snapshot_store.py
+++ b/torchdata/nodes/snapshot_store.py
@@ -84,7 +84,7 @@ def get_initial_snapshot(self, thread: threading.Thread, timeout: float = 60.0)
                 # thread may inadvertently report "is_alive()==False"
                 break
 
-        if isinstance(snapshot, ExceptionWrapper):
+        if snapshot is not None and isinstance(snapshot, ExceptionWrapper):
             snapshot.reraise()
 
         if snapshot is None or ver != self.SNAPSHOT_INIT_VERSION:
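Note for reviewers: the docstrings added above describe how these nodes compose into a pipeline. Below is a minimal usage sketch, not part of this diff. It assumes IterableWrapper from torchdata.nodes.adapters as the leaf source and the import paths shown; adjust both to the actual package layout.

# Sketch only: composes the nodes documented in this diff.
from torchdata.nodes import Batcher, Loader, Prefetcher
from torchdata.nodes.adapters import IterableWrapper  # assumed adapter for plain iterables
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler

# Two toy sources, mixed 3:1 by the weighted sampler.
source_nodes = {
    "ds0": IterableWrapper(range(10)),
    "ds1": IterableWrapper(range(10, 20)),
}
weights = {"ds0": 0.75, "ds1": 0.25}

mixer = MultiNodeWeightedSampler(source_nodes, weights)
node = Batcher(mixer, batch_size=4, drop_last=False)  # keep the final partial batch
node = Prefetcher(node, prefetch_factor=3)  # snapshot source state after every item (default)

loader = Loader(node)  # restart_on_stop_iteration=True by default
first_epoch = list(loader)

# The Loader exposes the state of the last-returned iterator, which can be
# restored later, as exercised by run_test_save_load_state in the tests above.
state = loader.state_dict()
loader.load_state_dict(state)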