Skip to content

Commit

Permalink
Adding doc strings ahead of new release (#1378)
Browse files Browse the repository at this point in the history
* doc strings

* tweak unit tests

* more doc strings

* fix linter

* tweaks

* address comments, fix spellings

* Add doc strings for pin_memory.py

* Doc strings for map.py
  • Loading branch information
divyanshk authored Dec 6, 2024
1 parent 72c4d12 commit 6557f22
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 62 deletions.
91 changes: 39 additions & 52 deletions test/nodes/test_multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def _setup_multi_node_weighted_sampler(self, num_samples, num_datasets, weights_
return Prefetcher(node, prefetch_factor=3)

def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
"""
Validation should fail if the keys of source_nodes and weights are not the same
"""
"""Test validation logic for MultiNodeWeightedSampler if the keys of source_nodes and weights are not the same"""
with self.assertRaisesRegex(
ValueError,
"keys of source_nodes and weights must be the same",
Expand All @@ -60,9 +58,7 @@ def test_multi_node_weighted_sampler_weight_sampler_keys_mismatch(self) -> None:
def test_multi_node_weighted_batch_sampler_invalid_weights_tensor_shape(
self,
) -> None:
"""
Validation should fail if the shape of the weights tensor is invalid
"""
"""Test validation logic for MultiNodeWeightedSampler if the shape of the weights tensor is invalid"""
with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
MultiNodeWeightedSampler(
self.datasets,
Expand All @@ -72,9 +68,7 @@ def test_multi_node_weighted_batch_sampler_invalid_weights_tensor_shape(
def test_multi_node_weighted_batch_sampler_negative_weights(
self,
) -> None:
"""
Validation should fail if the value of the weights tensor is invalid
"""
"""Test validation logic for MultiNodeWeightedSampler if the value of the weights tensor is invalid"""
with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
MultiNodeWeightedSampler(
self.datasets,
Expand All @@ -84,16 +78,15 @@ def test_multi_node_weighted_batch_sampler_negative_weights(
def test_multi_node_weighted_batch_sampler_zero_weights(
self,
) -> None:
"""
Validation should fail if the value of the weights tensor is invalid
"""
"""Test validation logic for MultiNodeWeightedSampler if the value of the weights tensor is invalid"""
with self.assertRaisesRegex(ValueError, " weights must be a 1d sequence, non-negative, and non-zero"):
MultiNodeWeightedSampler(
self.datasets,
weights={f"ds{i}": 10 * i for i in range(self._num_datasets)},
)

def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
"""Test MultiNodeWeightedSampler with stop criteria FIRST_DATASET_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
Expand All @@ -115,6 +108,7 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
mixer.reset()

def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
"""Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
Expand All @@ -139,6 +133,7 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
mixer.reset()

def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"""
mixer = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
Expand All @@ -154,49 +149,15 @@ def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"])
mixer.reset()

@parameterized.expand(
itertools.product(
[1, 4, 7],
[
StopCriteria.ALL_DATASETS_EXHAUSTED,
StopCriteria.FIRST_DATASET_EXHAUSTED,
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
],
)
)
def test_save_load_state_mixer(self, midpoint: int, stop_criteria: str):
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
run_test_save_load_state(self, mixer, midpoint)

@parameterized.expand(
itertools.product(
[100, 500, 1200],
[
StopCriteria.ALL_DATASETS_EXHAUSTED,
StopCriteria.FIRST_DATASET_EXHAUSTED,
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
],
)
)
def test_multi_node_weighted_large_sample_size(self, midpoint, stop_criteria) -> None:
num_samples = 1500
num_datasets = 5

mixer = self._setup_multi_node_weighted_sampler(
num_samples,
num_datasets,
self._weights_fn,
stop_criteria,
)
run_test_save_load_state(self, mixer, midpoint)

@parameterized.expand([(1, 8), (8, 32)])
def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size):
"""Test MultiNodeWeightedSampler with different rank and world size"""
mixer = MultiNodeWeightedSampler(self.datasets, self.weights, rank=rank, world_size=world_size)
self.assertEqual(mixer.rank, rank)
self.assertEqual(mixer.world_size, world_size)

def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
"""Test MultiNodeWeightedSampler with different results for different ranks"""
world_size = 8
global_results = []
for rank in range(world_size):
Expand All @@ -211,14 +172,16 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
self.assertEqual(unique_results, global_results)

def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
# Check the mixer node
"""Test MultiNodeWeightedSampler with different results in each epoch"""

# Check for the mixer node only
mixer = MultiNodeWeightedSampler(
self.datasets,
self.weights,
)

overall_results = []
for i in range(self._num_epochs):
for _ in range(self._num_epochs):
results = list(mixer)
overall_results.append(results)
mixer.reset()
Expand All @@ -230,7 +193,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):

self.assertEqual(unique_results, overall_results)

# Check mixer along with Prefetcher node
# Check for mixer along with Prefetcher node
node = self._setup_multi_node_weighted_sampler(
self._num_samples,
self._num_datasets,
Expand All @@ -239,7 +202,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
)

overall_results = []
for i in range(self._num_epochs):
for _ in range(self._num_epochs):
results = list(node)
overall_results.append(results)
node.reset()
Expand All @@ -262,6 +225,7 @@ def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
)
)
def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
"""Test MultiNodeWeightedSampler with saving and loading of state across multiple epochs"""
node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
run_test_save_load_state(self, node, midpoint)

Expand All @@ -283,3 +247,26 @@ def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoi
stop_criteria=stop_criteria,
)
run_test_save_load_state(self, node, midpoint)

@parameterized.expand(
itertools.product(
[100, 500, 1200],
[
StopCriteria.ALL_DATASETS_EXHAUSTED,
StopCriteria.FIRST_DATASET_EXHAUSTED,
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
],
)
)
def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, stop_criteria) -> None:
"""Test MultiNodeWeightedSampler (larger sample sizes) with saving and loading of state across multiple epochs"""
num_samples = 1500
num_datasets = 5

mixer = self._setup_multi_node_weighted_sampler(
num_samples,
num_datasets,
self._weights_fn,
stop_criteria,
)
run_test_save_load_state(self, mixer, midpoint)
2 changes: 1 addition & 1 deletion test/nodes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
self.i = 0
self.num_resets += 1

def next(self) -> Iterator[Dict[str, int]]:
def next(self) -> Dict[str, int]:
if self.i == self.n:
raise StopIteration()
ret = {"i": self.i, "resets": self.num_resets}
Expand Down
14 changes: 14 additions & 0 deletions torchdata/nodes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@


class Batcher(BaseNode[List[T]]):
"""Batcher node batches the data from the source node into batches of size batch_size.
If the source node is exhausted, it will return the batch or raise StopIteration.
If drop_last is True, the last batch will be dropped if it is smaller than batch_size.
If drop_last is False, the last batch will be returned even if it is smaller than batch_size.
Parameters:
source (BaseNode[T]): The source node to batch the data from.
batch_size (int): The size of the batch.
drop_last (bool): Whether to drop the last batch if it is smaller than batch_size. Default is True.
Attributes:
SOURCE_KEY (str): The key for the source node in the state dict.
"""

SOURCE_KEY = "source"

def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True):
Expand Down
20 changes: 20 additions & 0 deletions torchdata/nodes/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@


class Loader(Generic[T]):
"""Wraps the root node (iterator) and provides a stateful iterable interface.
The state of the last-returned iterator is returned by the state_dict() method, and can be
loaded using the load_state_dict() method.
Parameters:
root (BaseNode[T]): The root node of the data pipeline.
restart_on_stop_iteration (bool): Whether to restart the iterator when it reaches the end. Default is True
"""

def __init__(self, root: BaseNode[T], restart_on_stop_iteration: bool = True):
super().__init__()
self.root = root
Expand Down Expand Up @@ -49,6 +59,16 @@ def state_dict(self) -> Dict[str, Any]:


class LoaderIterator(BaseNode[T]):
"""An iterator class that wraps a root node and works with the Loader class.
The LoaderIterator object saves state of the underlying root node, and calls reset on the root node when
the iterator is exhausted or on a reset call. We look one step ahead to determine if the iterator is exhausted.
The state of the iterator is saved in the state_dict() method, and can be loaded on reset calls.
Parameters:
loader (Loader[T]): The loader object that contains the root node.
"""

NUM_YIELDED_KEY = "num_yielded"
ROOT_KEY = "root"

Expand Down
10 changes: 10 additions & 0 deletions torchdata/nodes/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,16 @@ class ParallelMapper(BaseNode[T]):
If in_order is true, the iterator will return items in the order from which they arrive
from source's iterator, potentially blocking even if other items are available.
Parameters:
source (BaseNode[X]): The source node to map over.
map_fn (Callable[[X], T]): The function to apply to each item from the source node.
num_workers (int): The number of workers to use for parallel processing.
in_order (bool): Whether to return items in the order from which they arrive from. Default is True.
method (Literal["thread", "process"]): The method to use for parallel processing. Default is "thread".
multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
"""

def __init__(
Expand Down
20 changes: 16 additions & 4 deletions torchdata/nodes/pin_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ def _pin_memory_loop(
device_id: Union[int, str],
device: Optional[str],
):
# this is fork of from torch.utils.data._utils.pin_memory import _pin_memory_loop
# to remove the index tuples
"""This is fork of from torch.utils.data._utils.pin_memory import _pin_memory_loop
to remove the index tuples.
# This setting is thread local, and prevents the copy in pin_memory from
# consuming all CPU cores.
This setting is thread local, and prevents the copy in pin_memory from
consuming all CPU cores.
"""

idx = MonotonicIndex()

Expand Down Expand Up @@ -94,6 +95,17 @@ def _put(


class PinMemory(BaseNode[T]):
"""Pins the data of the underlying node to a device. This is backed by torch.utils.data._utils.pin_memory._pin_memory_loop.
Parameters:
source (BaseNode[T]): The source node to pin the data from.
pin_memory_device (str): The device to pin the data to. Default is "".
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is
1, which means that the state of the source node will be snapshotted after every item. If set
to a higher value, the state of the source node will be snapshotted after every snapshot_frequency
items.
"""

def __init__(
self,
source: BaseNode[T],
Expand Down
12 changes: 11 additions & 1 deletion torchdata/nodes/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@


class Prefetcher(BaseNode[T]):
"""Prefetches data from the source node and stores it in a queue.
Parameters:
source (BaseNode[T]): The source node to prefetch data from.
prefetch_factor (int): The number of items to prefetch ahead of time.
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is
1, which means that the state of the source node will be snapshotted after every item. If set
to a higher value, the state of the source node will be snapshotted after every snapshot_frequency
items.
"""

def __init__(self, source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1):
super().__init__()
self.source = source
self.prefetch_factor = prefetch_factor
self.snapshot_frequency = snapshot_frequency
self._it: Optional[_SingleThreadedMapper[T]] = None
self._iter_for_state_dict: bool = False

def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
Expand Down
23 changes: 20 additions & 3 deletions torchdata/nodes/samplers/multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,31 @@ def get_state(self) -> Dict[str, Any]:


class _WeightedSampler:
"""A weighted sampler that samples from a list of weights.
The class implements the state using the following keys:
- g_state: The state of the random number generator.
- g_rank_state: The state of the random number generator for the rank.
- offset: The offset of the batch of indices.
Parameters:
weights (Dict[str, float]): A dictionary of weights for each source node.
seed (int): The seed for the random number generator.
rank (int): The rank of the current process.
world_size (int): The world size of the distributed environment.
random_tensor_batch_size (int): Generating random numbers in batches is faster than individually.
This setting controls the batch size, but is invisible to users and shouldn't need to be tuned. Default is 1000.
initial_state (Optional[Dict[str, Any]]): The initial state of the sampler. Default is None.
"""

def __init__(
self,
weights: Dict[str, float],
seed: int,
rank: int,
world_size: int,
epoch: int,
randon_tensor_batch_size: int = 1000,
random_tensor_batch_size: int = 1000,
initial_state: Optional[Dict[str, Any]] = None,
):
_names, _weights = [], []
Expand All @@ -219,7 +236,7 @@ def __init__(
self.names = _names
self.weights = torch.tensor(_weights, dtype=torch.float64)

self.randon_tensor_batch_size = randon_tensor_batch_size
self.random_tensor_batch_size = random_tensor_batch_size

self._g = torch.Generator()
self._g_rank = torch.Generator()
Expand All @@ -241,7 +258,7 @@ def _get_batch_of_indices(self) -> list[int]:
self._g_snapshot = self._g.get_state()
return torch.multinomial(
self.weights,
num_samples=self.randon_tensor_batch_size,
num_samples=self.random_tensor_batch_size,
replacement=True,
generator=self._g,
).tolist()
Expand Down
2 changes: 1 addition & 1 deletion torchdata/nodes/snapshot_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_initial_snapshot(self, thread: threading.Thread, timeout: float = 60.0)
# thread may inadvertently report "is_alive()==False"
break

if isinstance(snapshot, ExceptionWrapper):
if snapshot is not None and isinstance(snapshot, ExceptionWrapper):
snapshot.reraise()

if snapshot is None or ver != self.SNAPSHOT_INIT_VERSION:
Expand Down

0 comments on commit 6557f22

Please sign in to comment.