diff --git a/torchdata/nodes/base_node.py b/torchdata/nodes/base_node.py
index 33ab65043..e8ca8afb8 100644
--- a/torchdata/nodes/base_node.py
+++ b/torchdata/nodes/base_node.py
@@ -47,7 +47,7 @@ def __iter__(self):
 
     def reset(self, initial_state: Optional[dict] = None):
         """Resets the iterator to the beginning, or to the state passed in by initial_state.
-        Reset is a good place to put expensive initialization, as it will be lazily called when next() or state_dict() is called.
+        Reset is a good place to put expensive initialization, as it will be lazily called when ``next()`` or ``state_dict()`` is called.
         Subclasses must call ``super().reset(initial_state)``.
 
         Args:
@@ -57,14 +57,18 @@ def reset(self, initial_state: Optional[dict] = None):
         self.__initialized = True
 
     def get_state(self) -> Dict[str, Any]:
-        """Subclasses must implement this method, instead of state_dict(). Should only be called by BaseNode.
-        :return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future
+        """Subclasses must implement this method, instead of ``state_dict()``. Should only be called by BaseNode.
+
+        Returns:
+            Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future
         """
         raise NotImplementedError(type(self))
 
     def next(self) -> T:
-        """Subclasses must implement this method, instead of ``__next``. Should only be called by BaseNode.
-        :return: T - the next value in the sequence, or throw StopIteration
+        """Subclasses must implement this method, instead of ``__next__``. Should only be called by BaseNode.
+
+        Returns:
+            T - the next value in the sequence, or throw StopIteration
         """
         raise NotImplementedError(type(self))
 
@@ -83,7 +87,9 @@ def __next__(self):
 
     def state_dict(self) -> Dict[str, Any]:
         """Get a state_dict for this BaseNode.
-        :return: Dict[str, Any] - a state dict that may be passed to reset() at some point in the future.
+
+        Returns:
+            Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future.
         """
         try:
             self.__initialized
diff --git a/torchdata/nodes/samplers/multi_node_weighted_sampler.py b/torchdata/nodes/samplers/multi_node_weighted_sampler.py
index f4f4c7b4a..cc52fa92e 100644
--- a/torchdata/nodes/samplers/multi_node_weighted_sampler.py
+++ b/torchdata/nodes/samplers/multi_node_weighted_sampler.py
@@ -24,6 +24,7 @@ class MultiNodeWeightedSampler(BaseNode[T]):
     weights for sampling. `seed` is used to initialize the random number generator.
 
     The node implements the state using the following keys:
+
     - DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
     - DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
     - EPOCH_KEY: An epoch counter used to initialize the random number generator.
@@ -31,6 +32,7 @@ class MultiNodeWeightedSampler(BaseNode[T]):
     - WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler.
 
     We support multiple stopping criteria:
+
     - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets are exhausted. This is the default behavior.
     - FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
     - ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.
@@ -203,6 +205,7 @@ class _WeightedSampler:
     """A weighted sampler that samples from a list of weights.
 
     The class implements the state using the following keys:
+
     - g_state: The state of the random number generator.
     - g_rank_state: The state of the random number generator for the rank.
     - offset: The offset of the batch of indices.
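
To make the documented contract concrete, here is a minimal sketch of a custom node built on the docstrings above. `ListNode` and its `"idx"` state key are illustrative names, not part of torchdata: subclasses implement ``reset()``/``next()``/``get_state()``, while callers drive the node through iteration and ``state_dict()``.

```python
from typing import Any, Dict, List, Optional

from torchdata.nodes.base_node import BaseNode


class ListNode(BaseNode[int]):
    """Hypothetical node that yields items from an in-memory list."""

    def __init__(self, items: List[int]):
        super().__init__()
        self.items = items
        self._idx = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        # Per the docstring above, subclasses must call super().reset() here.
        super().reset(initial_state)
        self._idx = initial_state["idx"] if initial_state is not None else 0

    def next(self) -> int:
        # next() raises StopIteration once the sequence is exhausted.
        if self._idx >= len(self.items):
            raise StopIteration()
        value = self.items[self._idx]
        self._idx += 1
        return value

    def get_state(self) -> Dict[str, Any]:
        # Whatever is returned here is what state_dict() exposes to callers
        # and what may later be passed back into reset().
        return {"idx": self._idx}


node = ListNode([10, 20, 30])
print(next(node))             # 10; reset() is invoked lazily on first use
snapshot = node.state_dict()  # {"idx": 1}
node.reset(snapshot)          # resume from the captured position
print(next(node))             # 20
```

Note that the example never calls ``next()`` or ``get_state()`` directly on the node from caller code; as the docstrings state, those are invoked only by ``BaseNode`` itself via ``__next__`` and ``state_dict()``.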