From b3ab6454aafe44ccba8c904d5e7d1aac023bcf9f Mon Sep 17 00:00:00 2001 From: Andrew Ho Date: Mon, 2 Dec 2024 12:34:12 -0500 Subject: [PATCH] initial state fix for pin_memory and SingleThreadedMapper (#1377) Fixes initial state issues with nodes that use SingleThreadedMapper (pin_memory and prefetch) by passing an initial state_dict during initialization from source node. Makes snapshot queue thread-safe and more robust to worker errors --- test/nodes/test_map.py | 6 +- test/nodes/test_pin_memory.py | 4 +- test/nodes/test_prefetch.py | 6 +- test/nodes/test_snapshot_store.py | 119 ++++++++++++++++++++--------- test/nodes/utils.py | 30 ++++++++ torchdata/nodes/_populate_queue.py | 12 ++- torchdata/nodes/loader.py | 14 ++++ torchdata/nodes/map.py | 29 ++++--- torchdata/nodes/pin_memory.py | 12 ++- torchdata/nodes/snapshot_store.py | 69 ++++++++++++++--- 10 files changed, 227 insertions(+), 74 deletions(-) diff --git a/test/nodes/test_map.py b/test/nodes/test_map.py index 7705580ba..7caacd428 100644 --- a/test/nodes/test_map.py +++ b/test/nodes/test_map.py @@ -17,7 +17,7 @@ from torchdata.nodes.pin_memory import PinMemory from torchdata.nodes.prefetch import Prefetcher -from .utils import MockSource, RandomSleepUdf, run_test_save_load_state, udf_raises +from .utils import MockSource, RandomSleepUdf, run_test_save_load_state, StatefulRangeNode, udf_raises class TestMap(TestCase): @@ -120,7 +120,7 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr method = "thread" batch_size = 6 n = 80 - src = MockSource(num_samples=n) + src = StatefulRangeNode(n=n) node = Batcher(src, batch_size=batch_size, drop_last=False) node = ParallelMapper( node, @@ -145,7 +145,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_f batch_size = 6 n = 80 multiprocessing_context = None if IS_WINDOWS else "forkserver" - src = MockSource(num_samples=n) + src = StatefulRangeNode(n=n) node = Batcher(src, batch_size=batch_size, drop_last=False) node = ParallelMapper( node, diff --git a/test/nodes/test_pin_memory.py b/test/nodes/test_pin_memory.py index 01f0c0dac..b7262518e 100644 --- a/test/nodes/test_pin_memory.py +++ b/test/nodes/test_pin_memory.py @@ -18,7 +18,7 @@ from torchdata.nodes.pin_memory import PinMemory from torchdata.nodes.prefetch import Prefetcher -from .utils import Collate, IterInitError, MockSource, run_test_save_load_state +from .utils import Collate, IterInitError, MockSource, run_test_save_load_state, StatefulRangeNode @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @@ -70,7 +70,7 @@ def test_iter_init_error(self): def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int): batch_size = 6 n = 200 - node = MockSource(num_samples=n) + node = StatefulRangeNode(n=n) node = Batcher(node, batch_size=batch_size, drop_last=False) node = Mapper(node, Collate()) node = PinMemory(node, snapshot_frequency=snapshot_frequency) diff --git a/test/nodes/test_prefetch.py b/test/nodes/test_prefetch.py index 77505436d..79e117175 100644 --- a/test/nodes/test_prefetch.py +++ b/test/nodes/test_prefetch.py @@ -9,10 +9,12 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import TestCase +from torchdata.nodes.adapters import IterableWrapper from torchdata.nodes.batch import Batcher +from torchdata.nodes.loader import Loader from torchdata.nodes.prefetch import Prefetcher -from .utils import IterInitError, MockSource, run_test_save_load_state +from .utils import IterInitError, MockSource, run_test_save_load_state, StatefulRangeNode class TestPrefetcher(TestCase): @@ -44,7 +46,7 @@ def test_iter_init_error(self): def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int): batch_size = 6 n = 200 - src = MockSource(num_samples=n) + src = StatefulRangeNode(n=n) node = Batcher(src, batch_size=batch_size, drop_last=False) node = Prefetcher(node, prefetch_factor=8, snapshot_frequency=snapshot_frequency) run_test_save_load_state(self, node, midpoint) diff --git a/test/nodes/test_snapshot_store.py b/test/nodes/test_snapshot_store.py index 7f2c1949b..41f152575 100644 --- a/test/nodes/test_snapshot_store.py +++ b/test/nodes/test_snapshot_store.py @@ -4,46 +4,89 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import threading +import time + from torch.testing._internal.common_utils import TestCase -from torchdata.nodes.snapshot_store import DequeSnapshotStore +from torchdata.nodes.constants import QUEUE_TIMEOUT +from torchdata.nodes.exception_wrapper import StartupExceptionWrapper +from torchdata.nodes.snapshot_store import QueueSnapshotStore class TestDequeSnapshotStore(TestCase): def test_snapshot_store(self) -> None: - store = DequeSnapshotStore() - store.append({"a": 1}, 0) - store.append({"a": 2}, 10) - - self.assertEqual(len(store._deque), 2) - - val = store.pop_version(0) - self.assertEqual(val, {"a": 1}) - self.assertEqual(len(store._deque), 1) - val = store.pop_version(1) - self.assertIsNone(val) - self.assertEqual(len(store._deque), 1) - val = store.pop_version(7) - self.assertIsNone(val) - self.assertEqual(len(store._deque), 1) - val = store.pop_version(10) - self.assertEqual(val, {"a": 2}) - self.assertEqual(len(store._deque), 0) - - val = store.pop_version(11) - self.assertIsNone(val) - self.assertEqual(len(store._deque), 0) - - with self.assertRaisesRegex(ValueError, "is not strictly greater than"): - store.append({"a": 3}, 3) - - self.assertEqual(len(store._deque), 0) - - with self.assertRaisesRegex(ValueError, "is not strictly greater than"): - store.append({"a": 4}, 10) - self.assertEqual(len(store._deque), 0) - - store.append({"a": 4}, 11) - store.append({"a": 5}, 19) - val = store.pop_version(19) - self.assertEqual(val, {"a": 5}) - self.assertEqual(len(store._deque), 0) + for _ in range(100): + store = QueueSnapshotStore() + store.append({"a": 1}, 0) + store.append({"a": 2}, 10) + + self.assertEqual(len(store._q.queue), 2) + + val = store.pop_version(0) + self.assertEqual(val, {"a": 1}) + self.assertEqual(len(store._q.queue), 1) + val = store.pop_version(1) + self.assertIsNone(val) + self.assertEqual(len(store._q.queue), 1) + val = store.pop_version(7) + self.assertIsNone(val) + self.assertEqual(len(store._q.queue), 1) + val = store.pop_version(10) + self.assertEqual(val, {"a": 2}) + self.assertEqual(len(store._q.queue), 0) + + val = store.pop_version(11) + self.assertIsNone(val) + self.assertEqual(len(store._q.queue), 0) + + with self.assertRaisesRegex(ValueError, "is not strictly greater than"): + store.append({"a": 3}, 3) + + self.assertEqual(len(store._q.queue), 0) + + with self.assertRaisesRegex(ValueError, "is not strictly greater than"): + store.append({"a": 4}, 10) + self.assertEqual(len(store._q.queue), 0) + + store.append({"a": 4}, 11) + store.append({"a": 5}, 19) + val = store.pop_version(19) + self.assertEqual(val, {"a": 5}) + self.assertEqual(len(store._q.queue), 0) + + def test_init_error(self) -> None: + store = QueueSnapshotStore() + sleep_time = 0.1 + thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time)) + thread.start() + with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"): + store.get_initial_snapshot(thread, sleep_time) + thread.join() + + def test_timeout_error(self) -> None: + store = QueueSnapshotStore() + sleep_time = 0.1 + thread = threading.Thread(target=time.sleep, args=(sleep_time,)) + thread.start() + with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"): + store.get_initial_snapshot(thread, sleep_time * 0.1) + thread.join() + + def test_thread_dead_error(self) -> None: + # Test when thread is alive for longer than QUEUE_TIMEOUT but dies afterwards + store = QueueSnapshotStore() + thread = threading.Thread(target=time.sleep, args=(QUEUE_TIMEOUT * 3.0,)) + thread.start() + with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"): + store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 4.0) + thread.join() + + +def _worker_init_error(store, sleep_time): + # time.sleep(0.1 * sleep_time) + try: + raise RuntimeError("Test Startup Exception") + except Exception as e: + e = StartupExceptionWrapper(where="_worker_init_error") + store.append_initial_snapshot(e) + time.sleep(sleep_time) diff --git a/test/nodes/utils.py b/test/nodes/utils.py index a23028110..857db9396 100644 --- a/test/nodes/utils.py +++ b/test/nodes/utils.py @@ -118,6 +118,36 @@ def load_state_dict(self, state_dict: Dict[str, Any]): self._next_start = state_dict["_num_yielded"] +class StatefulRangeNode(BaseNode[Dict[str, int]]): + def __init__(self, n: int) -> None: + super().__init__() + self.n = n + self.i = 0 + self.num_resets = 0 + + def reset(self, initial_state: Optional[Dict[str, Any]] = None): + super().reset(initial_state) + if initial_state is not None: + self.i = initial_state["i"] + self.num_resets = initial_state["num_resets"] + else: + self.i = 0 + self.num_resets += 1 + + def next(self) -> Iterator[Dict[str, int]]: + if self.i == self.n: + raise StopIteration() + ret = {"i": self.i, "resets": self.num_resets} + self.i += 1 + return ret + + def get_state(self) -> Dict[str, Any]: + return { + "i": self.i, + "num_resets": self.num_resets, + } + + def run_test_save_load_state(test, node: BaseNode, midpoint: int): ############################## # Generate initial, midpoint, and end state_dict's diff --git a/torchdata/nodes/_populate_queue.py b/torchdata/nodes/_populate_queue.py index a3329baf9..cf22a4edc 100644 --- a/torchdata/nodes/_populate_queue.py +++ b/torchdata/nodes/_populate_queue.py @@ -46,7 +46,11 @@ def _populate_queue( # Include a monotonic index starting from 0 to each item in the queue idx = MonotonicIndex() - def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): + def _put( + item, + block: bool = True, + snapshot: Optional[Union[Dict[str, Any], StartupExceptionWrapper]] = None, + ): _idx = idx.get() if snapshot: snapshot_store.append(snapshot=snapshot, version=_idx) @@ -56,10 +60,10 @@ def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): assert ( isinstance(snapshot_frequency, int) and snapshot_frequency >= 0 ), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}" - src_iter = iter(source) + snapshot_store.append_initial_snapshot(snapshot=source.state_dict()) except Exception: e = StartupExceptionWrapper(where="in _populate_queue startup for device") - _put(e, block=False) + snapshot_store.append_initial_snapshot(snapshot=e) return yielded = 0 @@ -67,7 +71,7 @@ def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): if not semaphore.acquire(blocking=True, timeout=QUEUE_TIMEOUT): continue try: - item = next(src_iter) # FIXME: This may hang! + item = next(source) # FIXME: This may hang! yielded += 1 snapshot = None if snapshot_frequency > 0 and yielded % snapshot_frequency == 0: diff --git a/torchdata/nodes/loader.py b/torchdata/nodes/loader.py index ea182d618..93a033a5c 100644 --- a/torchdata/nodes/loader.py +++ b/torchdata/nodes/loader.py @@ -10,10 +10,23 @@ def __init__(self, root: BaseNode[T], restart_on_stop_iteration: bool = True): self.restart_on_stop_iteration = restart_on_stop_iteration self._next_iter_state_dict: Optional[Dict[str, Any]] = None self._it: Optional[LoaderIterator[T]] = None + # Tracks whether an iterator was created solely for getting a state_dict, in which case + # we don't want to reset the iterator. Consider these two cases, which should behave the same + # it = iter(loader) + # sd = loader.state_dict() # No extra __iter__ call as _it already exists + # for _ in it: ... + # -------- + # sd = loader.state_dict() # Calls __iter__ since _it is None + # it = iter(loader) # We don't want to reset the iterator here again + # for _ in it: ... + self._iter_for_state_dict: bool = False def __iter__(self): if self._it is None: self._it = LoaderIterator(self) + elif self._iter_for_state_dict: + self._iter_for_state_dict = False + return self._it # This was already pre-called to get a state dict if self._next_iter_state_dict is not None: self._it.reset(initial_state=self._next_iter_state_dict) @@ -31,6 +44,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): def state_dict(self) -> Dict[str, Any]: if self._it is None: iter(self) + self._iter_for_state_dict = True return self._it.state_dict() # type:ignore[union-attr] diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index a9d03c8bf..b999cb138 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -6,12 +6,13 @@ import queue import threading +import time from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union import torch.multiprocessing as mp from torchdata.nodes.base_node import BaseNode, T from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper -from torchdata.nodes.snapshot_store import DequeSnapshotStore, SnapshotStore +from torchdata.nodes.snapshot_store import QueueSnapshotStore, SnapshotStore from ._apply_udf import _apply_udf @@ -19,6 +20,8 @@ from .constants import QUEUE_TIMEOUT +ACK_TIMEOUT = 300 # Timeout after 5 minutes + # We define this protocol for type checking class _MultiprocessContext(Protocol): @@ -147,7 +150,7 @@ def __init__( else: self._snapshot = None self.source.reset() - self._snapshot_store = DequeSnapshotStore() + self._snapshot_store = QueueSnapshotStore() self._read_thread = threading.Thread( target=_populate_queue, @@ -192,6 +195,9 @@ def __init__( if self.in_order: self._sort_thread.start() + time.sleep(0.01) + self._snapshot = self._snapshot_store.get_initial_snapshot(thread=self._read_thread, timeout=ACK_TIMEOUT) + for i in range(fast_forward): try: next(self) @@ -250,13 +256,14 @@ def __del__(self): def _shutdown(self): self._stop.set() self._mp_stop.set() - if self._read_thread.is_alive(): + if hasattr(self, "_read_thread") and self._read_thread.is_alive(): self._read_thread.join(timeout=QUEUE_TIMEOUT * 5) - if self._sort_thread.is_alive(): + if hasattr(self, "_sort_thread") and self._sort_thread.is_alive(): self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5) - for t in self._workers: - if t.is_alive(): - t.join(timeout=QUEUE_TIMEOUT * 5) + if hasattr(self, "_workers"): + for t in self._workers: + if t.is_alive(): + t.join(timeout=QUEUE_TIMEOUT * 5) class ParallelMapper(BaseNode[T]): @@ -403,7 +410,7 @@ def __init__( else: self._snapshot = None self.source.reset() - self._snapshot_store = DequeSnapshotStore() + self._snapshot_store = QueueSnapshotStore() self._thread = threading.Thread( target=self.worker, args=( @@ -417,6 +424,10 @@ def __init__( daemon=True, ) self._thread.start() + + # Try and get initial snapshot + self._snapshot = self._snapshot_store.get_initial_snapshot(thread=self._thread, timeout=ACK_TIMEOUT) + for i in range(self._fast_forward): try: next(self) @@ -471,5 +482,5 @@ def __del__(self): def _shutdown(self): self._stop_event.set() - if self._thread.is_alive(): + if hasattr(self, "_thread") and self._thread.is_alive(): self._thread.join(timeout=QUEUE_TIMEOUT * 5) diff --git a/torchdata/nodes/pin_memory.py b/torchdata/nodes/pin_memory.py index f33bfd519..cc9e45caf 100644 --- a/torchdata/nodes/pin_memory.py +++ b/torchdata/nodes/pin_memory.py @@ -39,7 +39,11 @@ def _pin_memory_loop( idx = MonotonicIndex() - def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): + def _put( + item, + block: bool = True, + snapshot: Optional[Union[Dict[str, Any], StartupExceptionWrapper]] = None, + ): _idx = idx.get() if snapshot: snapshot_store.append(snapshot=snapshot, version=_idx) @@ -61,10 +65,10 @@ def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): assert ( isinstance(snapshot_frequency, int) and snapshot_frequency >= 0 ), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}" - src_iter = iter(source) + snapshot_store.append_initial_snapshot(snapshot=source.state_dict()) except Exception: e = StartupExceptionWrapper(where=f"in _pin_memory_loop startup for device {device_id}") - _put(e, block=False) + snapshot_store.append_initial_snapshot(snapshot=e) return yielded = 0 @@ -72,7 +76,7 @@ def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None): if not semaphore.acquire(blocking=True, timeout=0.1): continue try: - item = next(src_iter) + item = next(source) item = pin_memory(item, device) yielded += 1 snapshot = None diff --git a/torchdata/nodes/snapshot_store.py b/torchdata/nodes/snapshot_store.py index e92001778..c5c2189ce 100644 --- a/torchdata/nodes/snapshot_store.py +++ b/torchdata/nodes/snapshot_store.py @@ -1,7 +1,13 @@ -from collections import deque +import queue +import threading +import time from dataclasses import dataclass from typing import Any, Optional, Protocol +from torchdata.nodes.constants import QUEUE_TIMEOUT + +from torchdata.nodes.exception_wrapper import ExceptionWrapper + @dataclass class MonotonicIndex: @@ -25,26 +31,65 @@ def append(self, snapshot: Any, version: int): def pop_version(self, version: int) -> Optional[Any]: ... + def append_initial_snapshot(self, snapshot: Any): + ... + + def get_initial_snapshot(self, thread: threading.Thread, timeout: float) -> Any: + ... + + +class QueueSnapshotStore(SnapshotStore): + """A snapshot store that uses a queue to store snapshots""" -class DequeSnapshotStore(SnapshotStore): - """A snapshot store that uses a deque to store snapshots""" + SNAPSHOT_INIT_VERSION = -1 - def __init__(self, max_size: Optional[int] = None) -> None: - self._deque: deque = deque(maxlen=max_size) - self._max_version: int = -1 + def __init__(self) -> None: + self._q: queue.Queue = queue.Queue() + self._lock = threading.Lock() + self._max_version: int = -1000 def append(self, snapshot: Any, version: int) -> None: - if version <= self._max_version: - raise ValueError(f"{version=} is not strictly greater than {self._max_version=}") - self._max_version = version - self._deque.append((version, snapshot)) + with self._lock: + if version <= self._max_version: + raise ValueError(f"{version=} is not strictly greater than {self._max_version=}") + self._max_version = version + self._q.put((version, snapshot)) def pop_version(self, version: int) -> Optional[Any]: ver, val = None, None - while self._deque and version >= self._deque[0][0]: - ver, val = self._deque.popleft() + with self._lock: + while self._q.queue and version >= self._q.queue[0][0]: + ver, val = self._q.get_nowait() if ver == version: return val else: return None + + def append_initial_snapshot(self, snapshot: Any) -> None: + self.append(snapshot, self.SNAPSHOT_INIT_VERSION) + + def get_initial_snapshot(self, thread: threading.Thread, timeout: float = 60.0) -> Any: + snapshot = None + ver = None + + ack_t0 = time.time() + while snapshot is None and time.time() - ack_t0 < timeout: + try: + ver, snapshot = self._q.get(timeout=QUEUE_TIMEOUT) + except queue.Empty: + pass + if not thread.is_alive(): + # Don't test this until after QUEUE_TIMEOUT has elapsed because + # thread may inadvertently report "is_alive()==False" + break + + if isinstance(snapshot, ExceptionWrapper): + snapshot.reraise() + + if snapshot is None or ver != self.SNAPSHOT_INIT_VERSION: + raise RuntimeError( + f"Failed to get initial snapshot after {time.time() - ack_t0} seconds! {thread.is_alive()=}, {snapshot=}, {ver=}" + ) + + return snapshot