From 77bf3d18a74853d0b72998040f7df6739c5fa161 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Tue, 26 Nov 2024 14:19:40 -0800 Subject: [PATCH] make StopCriteria available in torchdata.nodes import (#1376) * make StopCriteria import with torchdata.nodes * use try except instead --- test/nodes/test_multi_node_weighted_sampler.py | 6 ++++++ torchdata/nodes/__init__.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py index f54907017..fd49c0ef7 100644 --- a/test/nodes/test_multi_node_weighted_sampler.py +++ b/test/nodes/test_multi_node_weighted_sampler.py @@ -32,6 +32,12 @@ def setUp(self) -> None: } self.weights = {f"ds{i}": self._weights_fn(i) for i in range(self._num_datasets)} + def test_torchdata_nodes_imports(self) -> None: + try: + from torchdata.nodes import MultiNodeWeightedSampler, StopCriteria # noqa + except ImportError: + self.fail("MultiNodeWeightedSampler or StopCriteria failed to import") + def _setup_multi_node_wighted_sampler(self, num_samples, num_datasets, weights_fn, stop_criteria) -> Prefetcher: datasets = {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)} diff --git a/torchdata/nodes/__init__.py b/torchdata/nodes/__init__.py index 0cd98c86f..2f8d1f287 100644 --- a/torchdata/nodes/__init__.py +++ b/torchdata/nodes/__init__.py @@ -12,6 +12,7 @@ from .pin_memory import PinMemory from .prefetch import Prefetcher from .samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler +from .samplers.stop_criteria import StopCriteria from .types import Stateful @@ -28,6 +29,7 @@ "Prefetcher", "SamplerWrapper", "Stateful", + "StopCriteria", "T", ]