Skip to content

Commit

Permalink
make StopCriteria available in torchdata.nodes import (#1376)
Browse files Browse the repository at this point in the history
* make StopCriteria import with torchdata.nodes

* use try except instead
  • Loading branch information
divyanshk authored Nov 26, 2024
1 parent 2a992ce commit 77bf3d1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/nodes/test_multi_node_weighted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def setUp(self) -> None:
}
self.weights = {f"ds{i}": self._weights_fn(i) for i in range(self._num_datasets)}

def test_torchdata_nodes_imports(self) -> None:
try:
from torchdata.nodes import MultiNodeWeightedSampler, StopCriteria # noqa
except ImportError:
self.fail("MultiNodeWeightedSampler or StopCriteria failed to import")

def _setup_multi_node_wighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher:

datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
Expand Down
2 changes: 2 additions & 0 deletions torchdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .pin_memory import PinMemory
from .prefetch import Prefetcher
from .samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
from .samplers.stop_criteria import StopCriteria
from .types import Stateful


Expand All @@ -28,6 +29,7 @@
"Prefetcher",
"SamplerWrapper",
"Stateful",
"StopCriteria",
"T",
]

Expand Down

0 comments on commit 77bf3d1

Please sign in to comment.