Add epoch based seed to MultiNodeWeightedSampler (#1369)
* initial commit

* unit tests

* remove extra function

* Add an epoch updater

* move multi-epoch support inside mixer class

* replace exact epoch with rng, includes PR#1377 for fixing prefetcher initial state

* clean up

* clean up some more

* add epoch seed based on rank, remove rng for epoch seeds, extend a test case

* loop over epochs to generate different seeds

* remove extra function, extra comment

* rename epoch_random_seed to epoch, remove comments

* Update docstring

* fix state key name
divyanshk authored Dec 3, 2024
1 parent b3ab645 commit 9468e00
Showing 4 changed files with 117 additions and 34 deletions.
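
For orientation, here is a minimal sketch of the behavior this commit enables (the toy datasets, weights, and epoch count are illustrative, not taken from the PR; only the `IterableWrapper`/`MultiNodeWeightedSampler` usage mirrors the tests below). Before this change, every reset() re-seeded the weighted sampler identically, so each epoch replayed the same mixing order; now reset() advances an internal epoch counter that feeds the seed, so epochs differ while staying reproducible from `seed`:

from torchdata.nodes import IterableWrapper, MultiNodeWeightedSampler

# Toy sources; any node-wrapped iterables work here.
datasets = {
    "ds0": IterableWrapper(range(0, 10)),
    "ds1": IterableWrapper(range(10, 20)),
}
weights = {"ds0": 0.7, "ds1": 0.3}

mixer = MultiNodeWeightedSampler(datasets, weights, seed=0)

epochs = []
for _ in range(3):
    epochs.append(list(mixer))  # drain one epoch
    mixer.reset()               # advances the epoch-based seed

# Each epoch should now draw a different mixing order, and the whole
# sequence of epochs is reproducible from the same seed.
assert len({tuple(e) for e in epochs}) == 3
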
test/nodes/test_multi_node_weighted_sampler.py (87 changes: 80 additions & 7 deletions)

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from enum import unique
 
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase
@@ -38,7 +37,7 @@ def test_torchdata_nodes_imports(self) -> None:
         except ImportError:
             self.fail("MultiNodeWeightedSampler or StopCriteria failed to import")
 
-    def _setup_multi_node_wighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher:
+    def _setup_multi_node_weighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher:
 
         datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
         weights = {f"ds{i}": weights_fn(i) for i in range(num_datasets)}
@@ -95,7 +94,7 @@ def test_multi_node_weighted_batch_sampler_zero_weights(
         )
 
     def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
-        mixer = self._setup_multi_node_wighted_sampler(
+        mixer = self._setup_multi_node_weighted_sampler(
            self._num_samples,
            self._num_datasets,
            self._weights_fn,
@@ -116,7 +115,7 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None:
         mixer.reset()
 
     def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
-        mixer = self._setup_multi_node_wighted_sampler(
+        mixer = self._setup_multi_node_weighted_sampler(
            self._num_samples,
            self._num_datasets,
            self._weights_fn,
@@ -140,7 +139,7 @@ def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None:
         mixer.reset()
 
     def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
-        mixer = self._setup_multi_node_wighted_sampler(
+        mixer = self._setup_multi_node_weighted_sampler(
            self._num_samples,
            self._num_datasets,
            self._weights_fn,
@@ -165,7 +164,7 @@ def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
             ],
         )
     )
-    def test_save_load_state_stateful(self, midpoint: int, stop_criteria: str):
+    def test_save_load_state_mixer(self, midpoint: int, stop_criteria: str):
         mixer = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
         run_test_save_load_state(self, mixer, midpoint)
 
@@ -183,7 +182,7 @@ def test_multi_node_weighted_large_sample_size(self, midpoint, stop_criteria) ->
         num_samples = 1500
         num_datasets = 5
 
-        mixer = self._setup_multi_node_wighted_sampler(
+        mixer = self._setup_multi_node_weighted_sampler(
            num_samples,
            num_datasets,
            self._weights_fn,
@@ -210,3 +209,77 @@ def test_multi_node_weighted_batch_sampler_results_for_ranks(self):
             if results not in unique_results:
                 unique_results.append(results)
         self.assertEqual(unique_results, global_results)
+
+    def test_multi_node_weighted_batch_sampler_results_for_multiple_epochs(self):
+        # Check the mixer node
+        mixer = MultiNodeWeightedSampler(
+            self.datasets,
+            self.weights,
+        )
+
+        overall_results = []
+        for i in range(self._num_epochs):
+            results = list(mixer)
+            overall_results.append(results)
+            mixer.reset()
+
+        unique_results = []
+        for results in overall_results:
+            if results not in unique_results:
+                unique_results.append(results)
+
+        self.assertEqual(unique_results, overall_results)
+
+        # Check mixer along with Prefetcher node
+        node = self._setup_multi_node_weighted_sampler(
+            self._num_samples,
+            self._num_datasets,
+            self._weights_fn,
+            stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+        )
+
+        overall_results = []
+        for i in range(self._num_epochs):
+            results = list(node)
+            overall_results.append(results)
+            node.reset()
+
+        unique_results = []
+        for results in overall_results:
+            if results not in unique_results:
+                unique_results.append(results)
+
+        self.assertEqual(unique_results, overall_results)
+
+    @parameterized.expand(
+        itertools.product(
+            [1, 4, 7],
+            [
+                StopCriteria.ALL_DATASETS_EXHAUSTED,
+                StopCriteria.FIRST_DATASET_EXHAUSTED,
+                StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+            ],
+        )
+    )
+    def test_save_load_state_mixer_over_multiple_epochs(self, midpoint: int, stop_criteria: str):
+        node = MultiNodeWeightedSampler(self.datasets, self.weights, stop_criteria)
+        run_test_save_load_state(self, node, midpoint)
+
+    @parameterized.expand(
+        itertools.product(
+            [1, 4, 7],
+            [
+                StopCriteria.ALL_DATASETS_EXHAUSTED,
+                StopCriteria.FIRST_DATASET_EXHAUSTED,
+                StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+            ],
+        )
+    )
+    def test_save_load_state_mixer_over_multiple_epochs_with_prefetcher(self, midpoint: int, stop_criteria: str):
+        node = self._setup_multi_node_weighted_sampler(
+            self._num_samples,
+            self._num_datasets,
+            self._weights_fn,
+            stop_criteria=stop_criteria,
+        )
+        run_test_save_load_state(self, node, midpoint)
torchdata/nodes/adapters.py (1 change: 0 additions & 1 deletion)

@@ -96,7 +96,6 @@ class SamplerWrapper(BaseNode[T]):
     :param epoch_updater: Optional[Callable[[int], int]] = None - callback to update epoch at start of new iteration. It's called at the beginning of each iterator request, except the first one.
     """
 
-    NEXT_EPOCH_KEY = "_next_epoch"
     NUM_YIELDED_KEY = "_num_yielded"
    EPOCH_KEY = "_epoch"
    SAMPLER_KEY = "_sampler"
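
The epoch_updater hook named in the docstring above controls how SamplerWrapper advances its epoch between iterations. A hedged sketch of how such a callback might be supplied (the wrap-around policy and the RandomSampler setup are illustrative assumptions, not part of this commit):

import torch
from torch.utils.data import RandomSampler, TensorDataset
from torchdata.nodes import SamplerWrapper

dataset = TensorDataset(torch.arange(100))

node = SamplerWrapper(
    RandomSampler(dataset),
    # Hypothetical policy: wrap the epoch counter every 10 epochs
    # instead of letting it grow without bound.
    epoch_updater=lambda epoch: (epoch + 1) % 10,
)

first = list(node)   # the first iteration keeps the initial epoch
node.reset()         # later iterations invoke the updater before starting
second = list(node)  # drawn under the updated epoch
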
torchdata/nodes/samplers/multi_node_weighted_sampler.py (59 changes: 35 additions & 24 deletions)

@@ -24,10 +24,11 @@ class MultiNodeWeightedSampler(BaseNode[T]):
     weights for sampling. `seed` is used to initialize the random number generator.
 
     The node implements the state using the following keys:
+    - DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
+    - DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
+    - EPOCH_KEY: An epoch counter used to initialize the random number generator.
     - NUM_YIELDED_KEY: The number of items yielded.
     - WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler.
-    - DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
-    - DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
 
     We support multiple stopping criteria:
     - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets
@@ -49,9 +50,10 @@ class MultiNodeWeightedSampler(BaseNode[T]):
     """
 
     DATASET_NODE_STATES_KEY = "dataset_node_states"
+    DATASETS_EXHAUSTED_KEY = "datasets_exhausted"
+    EPOCH_KEY = "epoch"
     NUM_YIELDED_KEY = "num_yielded"
     WEIGHTED_SAMPLER_STATE_KEY = "weighted_sampler_state"
-    DATASETS_EXHAUSTED_KEY = "datasets_exhausted"
 
     def __init__(
         self,
@@ -63,19 +65,24 @@ def __init__(
         seed: int = 0,
     ) -> None:
         super().__init__()
+
         self.source_nodes = source_nodes
         self.weights = weights
         self.stop_criteria = stop_criteria
         self.dataset_names = list(self.source_nodes.keys())
         self._num_yielded = 0
+        self._started = False
         self.seed = seed
 
         # Setup rank and world size
         if rank is None or world_size is None:
             self.rank, self.world_size = get_rank_and_world_size()
         else:
             self.rank = rank
             self.world_size = world_size
+
+        self._epoch = 0
+
         self._validate()
 
     def _validate(self) -> None:
@@ -105,31 +112,35 @@ def _validate(self) -> None:
 
     def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         super().reset(initial_state)
+
         if initial_state is not None:
             self._num_yielded = initial_state[self.NUM_YIELDED_KEY]
-            self._weighted_sampler = _WeightedSampler(
-                weights=self.weights,
-                seed=self.seed,
-                rank=self.rank,
-                world_size=self.world_size,
-                initial_state=initial_state[self.WEIGHTED_SAMPLER_STATE_KEY],
-            )
+            self._epoch = initial_state[self.EPOCH_KEY]
+            self._weighted_sampler = self._get_new_weighted_sampler(initial_state)
             self._datasets_exhausted = initial_state[self.DATASETS_EXHAUSTED_KEY]
             for k in self.dataset_names:
                 self.source_nodes[k].reset(initial_state[self.DATASET_NODE_STATES_KEY][k])
         else:
             # Force a fresh iterator from all source nodes
             self._num_yielded = 0
-            self._weighted_sampler = _WeightedSampler(
-                weights=self.weights,
-                seed=self.seed,
-                rank=self.rank,
-                world_size=self.world_size,
-            )
+
+            if self._started:
+                self._epoch += 1
+            self._weighted_sampler = self._get_new_weighted_sampler()
+
             self._datasets_exhausted = {key: False for key in self.weights.keys()}
             for k in self.dataset_names:
                 self.source_nodes[k].reset()
+        self._started = False
+
+    def _get_new_weighted_sampler(self, initial_state=None):
+        return _WeightedSampler(
+            weights=self.weights,
+            seed=self.seed,
+            rank=self.rank,
+            world_size=self.world_size,
+            epoch=self._epoch,
+            initial_state=(initial_state[self.WEIGHTED_SAMPLER_STATE_KEY] if initial_state is not None else None),
+        )
 
     def _check_for_stop_iteration(self) -> None:
         if all(self._datasets_exhausted.values()):
@@ -146,6 +157,7 @@ def _check_for_stop_iteration(self) -> None:
             return
 
     def next(self) -> T:
+        self._started = True
         while True:
             self._check_for_stop_iteration()
 
@@ -180,10 +192,11 @@ def next(self) -> T:
 
     def get_state(self) -> Dict[str, Any]:
         return {
-            self.NUM_YIELDED_KEY: self._num_yielded,
-            self.WEIGHTED_SAMPLER_STATE_KEY: self._weighted_sampler.state_dict(),
             self.DATASETS_EXHAUSTED_KEY: copy.deepcopy(self._datasets_exhausted),
             self.DATASET_NODE_STATES_KEY: {k: self.source_nodes[k].state_dict() for k in self.dataset_names},
+            self.EPOCH_KEY: self._epoch,
+            self.NUM_YIELDED_KEY: self._num_yielded,
+            self.WEIGHTED_SAMPLER_STATE_KEY: self._weighted_sampler.state_dict(),
         }
 
 
@@ -194,6 +207,7 @@ def __init__(
         seed: int,
         rank: int,
         world_size: int,
+        epoch: int,
         randon_tensor_batch_size: int = 1000,
         initial_state: Optional[Dict[str, Any]] = None,
     ):
@@ -210,14 +224,13 @@
         self._g = torch.Generator()
         self._g_rank = torch.Generator()
 
-        seed = _get_rank_seed(seed, self._g_rank, rank, world_size)
+        self.epoch = epoch
+        seed = _get_rank_seed(seed, self._g_rank, rank, world_size, self.epoch)
         self._g.manual_seed(seed)
 
         self._g_snapshot = self._g.get_state()
-        self._g_rank_snapshot = self._g_rank.get_state()
         if initial_state is not None:
             self._g.set_state(initial_state["g_state"])
-            self._g_rank.set_state(initial_state["g_rank_state"])
             self._offset = initial_state["offset"]
         else:
             self._offset = 0
@@ -226,7 +239,6 @@
 
     def _get_batch_of_indices(self) -> list[int]:
         self._g_snapshot = self._g.get_state()
-        self._g_rank_snapshot = self._g_rank.get_state()
         return torch.multinomial(
             self.weights,
             num_samples=self.randon_tensor_batch_size,
@@ -248,6 +260,5 @@ def __next__(self):
     def state_dict(self):
         return {
             "g_state": self._g_snapshot,
-            "g_rank_state": self._g_rank_snapshot,
             "offset": self._offset,
         }
torchdata/nodes/samplers/utils.py (4 changes: 2 additions & 2 deletions)

@@ -10,9 +10,9 @@
 import torch.distributed as dist
 
 
-def _get_rank_seed(seed: int, generator_rank: torch.Generator, rank: int, world_size: int) -> int:
+def _get_rank_seed(seed: int, generator_rank: torch.Generator, rank: int, world_size: int, epoch: int) -> int:
     generator_rank.manual_seed(seed * world_size + rank)
-    return int(torch.randint(0, 2 ** 32 - 1, size=(1,), generator=generator_rank).item())
+    return int(torch.randint(0, 2 ** 32 - 1, size=(epoch + 1,), generator=generator_rank)[-1].item())
 
 
 def get_rank_and_world_size() -> tuple[int, int]:
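
To see why drawing `epoch + 1` values and keeping the last one works, here is a standalone sketch of the updated helper (`rank_seed` is a local rename for illustration): for a fixed (seed, rank, world_size), each epoch extends the same deterministic stream by one draw, so consecutive epochs get distinct yet reproducible seeds.

import torch

def rank_seed(seed: int, rank: int, world_size: int, epoch: int) -> int:
    g = torch.Generator()
    g.manual_seed(seed * world_size + rank)
    # Draw epoch + 1 values from the rank-seeded stream and keep the last.
    return int(torch.randint(0, 2 ** 32 - 1, size=(epoch + 1,), generator=g)[-1].item())

seeds = [rank_seed(seed=0, rank=0, world_size=2, epoch=e) for e in range(3)]
assert seeds == [rank_seed(0, 0, 2, e) for e in range(3)]  # reproducible
assert len(set(seeds)) == 3  # distinct per epoch, with overwhelming probability
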
