initial state fix for pin_memory and SingleThreadedMapper (#1377)
Fixes initial-state issues with nodes that use SingleThreadedMapper (pin_memory and prefetch) by passing an initial state_dict from the source node during initialization. Also makes the snapshot queue thread-safe and more robust to worker errors.
andrewkho authored Dec 2, 2024
1 parent 77bf3d1 commit b3ab645
Showing 10 changed files with 227 additions and 74 deletions.
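
The core of the fix is a startup handshake: before producing any items, the worker publishes the source node's initial state_dict through the snapshot store (or a wrapped startup exception if setup fails), and the main thread blocks on that report during __init__. Below is a minimal sketch of that handshake under assumed semantics; the real QueueSnapshotStore lives in torchdata/nodes/snapshot_store.py, whose diff is not loaded on this page, so only the call sites in the diffs below are authoritative.

import queue
import threading

class SketchSnapshotStore:
    # Hypothetical stand-in for QueueSnapshotStore: queue-backed, hence thread-safe.
    def __init__(self):
        self._q = queue.Queue()

    def append_initial_snapshot(self, snapshot):
        # Worker side: report the source's initial state (or a wrapped error)
        # before producing any items.
        self._q.put(snapshot)

    def get_initial_snapshot(self, thread, timeout):
        # Main-thread side: block until the worker reports in, but fail fast
        # with a useful error instead of hanging forever.
        try:
            snapshot = self._q.get(timeout=timeout)
        except queue.Empty:
            raise RuntimeError(f"Failed to get initial snapshot: {thread.is_alive()=}")
        if isinstance(snapshot, Exception):
            raise snapshot  # a startup failure surfaces in the main thread
        return snapshot

def worker(source_state, store):
    # The worker reports its source's initial state as its first action.
    store.append_initial_snapshot(source_state)

store = SketchSnapshotStore()
t = threading.Thread(target=worker, args=({"i": 0}, store))
t.start()
print(store.get_initial_snapshot(thread=t, timeout=5.0))  # -> {'i': 0}
t.join()
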
6 changes: 3 additions & 3 deletions test/nodes/test_map.py
@@ -17,7 +17,7 @@
from torchdata.nodes.pin_memory import PinMemory
from torchdata.nodes.prefetch import Prefetcher

-from .utils import MockSource, RandomSleepUdf, run_test_save_load_state, udf_raises
+from .utils import MockSource, RandomSleepUdf, run_test_save_load_state, StatefulRangeNode, udf_raises


class TestMap(TestCase):
@@ -120,7 +120,7 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
method = "thread"
batch_size = 6
n = 80
-src = MockSource(num_samples=n)
+src = StatefulRangeNode(n=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
node,
@@ -145,7 +145,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int):
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
-src = MockSource(num_samples=n)
+src = StatefulRangeNode(n=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
node,
4 changes: 2 additions & 2 deletions test/nodes/test_pin_memory.py
@@ -18,7 +18,7 @@
from torchdata.nodes.pin_memory import PinMemory
from torchdata.nodes.prefetch import Prefetcher

-from .utils import Collate, IterInitError, MockSource, run_test_save_load_state
+from .utils import Collate, IterInitError, MockSource, run_test_save_load_state, StatefulRangeNode


@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@@ -70,7 +70,7 @@ def test_iter_init_error(self):
def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int):
batch_size = 6
n = 200
-node = MockSource(num_samples=n)
+node = StatefulRangeNode(n=n)
node = Batcher(node, batch_size=batch_size, drop_last=False)
node = Mapper(node, Collate())
node = PinMemory(node, snapshot_frequency=snapshot_frequency)
6 changes: 4 additions & 2 deletions test/nodes/test_prefetch.py
@@ -9,10 +9,12 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
+from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.batch import Batcher
+from torchdata.nodes.loader import Loader
from torchdata.nodes.prefetch import Prefetcher

-from .utils import IterInitError, MockSource, run_test_save_load_state
+from .utils import IterInitError, MockSource, run_test_save_load_state, StatefulRangeNode


class TestPrefetcher(TestCase):
@@ -44,7 +46,7 @@ def test_iter_init_error(self):
def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int):
batch_size = 6
n = 200
-src = MockSource(num_samples=n)
+src = StatefulRangeNode(n=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = Prefetcher(node, prefetch_factor=8, snapshot_frequency=snapshot_frequency)
run_test_save_load_state(self, node, midpoint)
119 changes: 81 additions & 38 deletions test/nodes/test_snapshot_store.py
@@ -4,46 +4,89 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import threading
+import time

from torch.testing._internal.common_utils import TestCase
-from torchdata.nodes.snapshot_store import DequeSnapshotStore
+from torchdata.nodes.constants import QUEUE_TIMEOUT
+from torchdata.nodes.exception_wrapper import StartupExceptionWrapper
+from torchdata.nodes.snapshot_store import QueueSnapshotStore


class TestDequeSnapshotStore(TestCase):
def test_snapshot_store(self) -> None:
-store = DequeSnapshotStore()
-store.append({"a": 1}, 0)
-store.append({"a": 2}, 10)
-
-self.assertEqual(len(store._deque), 2)
-
-val = store.pop_version(0)
-self.assertEqual(val, {"a": 1})
-self.assertEqual(len(store._deque), 1)
-val = store.pop_version(1)
-self.assertIsNone(val)
-self.assertEqual(len(store._deque), 1)
-val = store.pop_version(7)
-self.assertIsNone(val)
-self.assertEqual(len(store._deque), 1)
-val = store.pop_version(10)
-self.assertEqual(val, {"a": 2})
-self.assertEqual(len(store._deque), 0)
-
-val = store.pop_version(11)
-self.assertIsNone(val)
-self.assertEqual(len(store._deque), 0)
-
-with self.assertRaisesRegex(ValueError, "is not strictly greater than"):
-store.append({"a": 3}, 3)
-
-self.assertEqual(len(store._deque), 0)
-
-with self.assertRaisesRegex(ValueError, "is not strictly greater than"):
-store.append({"a": 4}, 10)
-self.assertEqual(len(store._deque), 0)
-
-store.append({"a": 4}, 11)
-store.append({"a": 5}, 19)
-val = store.pop_version(19)
-self.assertEqual(val, {"a": 5})
-self.assertEqual(len(store._deque), 0)
+for _ in range(100):
+store = QueueSnapshotStore()
+store.append({"a": 1}, 0)
+store.append({"a": 2}, 10)
+
+self.assertEqual(len(store._q.queue), 2)
+
+val = store.pop_version(0)
+self.assertEqual(val, {"a": 1})
+self.assertEqual(len(store._q.queue), 1)
+val = store.pop_version(1)
+self.assertIsNone(val)
+self.assertEqual(len(store._q.queue), 1)
+val = store.pop_version(7)
+self.assertIsNone(val)
+self.assertEqual(len(store._q.queue), 1)
+val = store.pop_version(10)
+self.assertEqual(val, {"a": 2})
+self.assertEqual(len(store._q.queue), 0)
+
+val = store.pop_version(11)
+self.assertIsNone(val)
+self.assertEqual(len(store._q.queue), 0)
+
+with self.assertRaisesRegex(ValueError, "is not strictly greater than"):
+store.append({"a": 3}, 3)
+
+self.assertEqual(len(store._q.queue), 0)
+
+with self.assertRaisesRegex(ValueError, "is not strictly greater than"):
+store.append({"a": 4}, 10)
+self.assertEqual(len(store._q.queue), 0)
+
+store.append({"a": 4}, 11)
+store.append({"a": 5}, 19)
+val = store.pop_version(19)
+self.assertEqual(val, {"a": 5})
+self.assertEqual(len(store._q.queue), 0)

+def test_init_error(self) -> None:
+store = QueueSnapshotStore()
+sleep_time = 0.1
+thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time))
+thread.start()
+with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"):
+store.get_initial_snapshot(thread, sleep_time)
+thread.join()
+
+def test_timeout_error(self) -> None:
+store = QueueSnapshotStore()
+sleep_time = 0.1
+thread = threading.Thread(target=time.sleep, args=(sleep_time,))
+thread.start()
+with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"):
+store.get_initial_snapshot(thread, sleep_time * 0.1)
+thread.join()
+
+def test_thread_dead_error(self) -> None:
+# Test when thread is alive for longer than QUEUE_TIMEOUT but dies afterwards
+store = QueueSnapshotStore()
+thread = threading.Thread(target=time.sleep, args=(QUEUE_TIMEOUT * 3.0,))
+thread.start()
+with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
+store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 4.0)
+thread.join()
+
+
+def _worker_init_error(store, sleep_time):
+# time.sleep(0.1 * sleep_time)
+try:
+raise RuntimeError("Test Startup Exception")
+except Exception as e:
+e = StartupExceptionWrapper(where="_worker_init_error")
+store.append_initial_snapshot(e)
+time.sleep(sleep_time)
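
For orientation, the versioned-pop semantics exercised above amount to: discard snapshots older than the requested version, return the snapshot whose version matches exactly, else None. A sketch using only calls shown in this test:

from torchdata.nodes.snapshot_store import QueueSnapshotStore

store = QueueSnapshotStore()
store.append(snapshot={"i": 4}, version=4)
store.append(snapshot={"i": 8}, version=8)

# A consumer that has just yielded item 8 drops the stale version-4 snapshot
# on the way to the exact match:
assert store.pop_version(8) == {"i": 8}
assert store.pop_version(9) is None  # nothing recorded at version 9
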
30 changes: 30 additions & 0 deletions test/nodes/utils.py
@@ -118,6 +118,36 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
self._next_start = state_dict["_num_yielded"]


+class StatefulRangeNode(BaseNode[Dict[str, int]]):
+def __init__(self, n: int) -> None:
+super().__init__()
+self.n = n
+self.i = 0
+self.num_resets = 0
+
+def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+super().reset(initial_state)
+if initial_state is not None:
+self.i = initial_state["i"]
+self.num_resets = initial_state["num_resets"]
+else:
+self.i = 0
+self.num_resets += 1
+
+def next(self) -> Iterator[Dict[str, int]]:
+if self.i == self.n:
+raise StopIteration()
+ret = {"i": self.i, "resets": self.num_resets}
+self.i += 1
+return ret
+
+def get_state(self) -> Dict[str, Any]:
+return {
+"i": self.i,
+"num_resets": self.num_resets,
+}


def run_test_save_load_state(test, node: BaseNode, midpoint: int):
##############################
# Generate initial, midpoint, and end state_dict's
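
As a quick illustration of why these tests switch from MockSource to this node: its position and reset count live in its state, so a mid-stream restore is directly observable. A sketch of the round trip it enables, using the reset/next/get_state protocol defined above (and assuming BaseNode dispatches next() via __next__, as the next(...) calls in the diffs below suggest):

node = StatefulRangeNode(n=5)
node.reset()                      # fresh start; bumps num_resets to 1
assert [next(node)["i"] for _ in range(2)] == [0, 1]
snapshot = node.get_state()       # {'i': 2, 'num_resets': 1}

resumed = StatefulRangeNode(n=5)
resumed.reset(initial_state=snapshot)  # resume exactly where the snapshot left off
assert next(resumed)["i"] == 2
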
12 changes: 8 additions & 4 deletions torchdata/nodes/_populate_queue.py
@@ -46,7 +46,11 @@ def _populate_queue(
# Include a monotonic index starting from 0 to each item in the queue
idx = MonotonicIndex()

-def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None):
+def _put(
+item,
+block: bool = True,
+snapshot: Optional[Union[Dict[str, Any], StartupExceptionWrapper]] = None,
+):
_idx = idx.get()
if snapshot:
snapshot_store.append(snapshot=snapshot, version=_idx)
@@ -56,18 +60,18 @@ def _put(item, block: bool = True, snapshot: Optional[Dict[str, Any]] = None):
assert (
isinstance(snapshot_frequency, int) and snapshot_frequency >= 0
), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}"
-src_iter = iter(source)
+snapshot_store.append_initial_snapshot(snapshot=source.state_dict())
except Exception:
e = StartupExceptionWrapper(where="in _populate_queue startup for device")
-_put(e, block=False)
+snapshot_store.append_initial_snapshot(snapshot=e)
return

yielded = 0
while not stop_event.is_set():
if not semaphore.acquire(blocking=True, timeout=QUEUE_TIMEOUT):
continue
try:
-item = next(src_iter) # FIXME: This may hang!
+item = next(source) # FIXME: This may hang!
yielded += 1
snapshot = None
if snapshot_frequency > 0 and yielded % snapshot_frequency == 0:
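
Condensed, the producer's new startup-and-loop shape is roughly the following. This is a sketch that omits the semaphore, stop_event, and versioned _put plumbing shown in the diff; only the append_initial_snapshot calls and next(source) mirror it directly:

from torchdata.nodes.exception_wrapper import StartupExceptionWrapper

def populate_queue_sketch(source, q, snapshot_store, snapshot_frequency):
    try:
        # New: publish the source's initial state before the first item...
        snapshot_store.append_initial_snapshot(snapshot=source.state_dict())
    except Exception:
        # ...and route startup failures through the same channel, so the main
        # thread's get_initial_snapshot() re-raises instead of timing out.
        e = StartupExceptionWrapper(where="in _populate_queue startup")
        snapshot_store.append_initial_snapshot(snapshot=e)
        return

    yielded = 0
    while True:
        try:
            item = next(source)  # new: iterate the node directly, no iter(source)
        except StopIteration:
            break
        yielded += 1
        snapshot = None
        if snapshot_frequency > 0 and yielded % snapshot_frequency == 0:
            snapshot = source.state_dict()
        q.put((item, snapshot))
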
14 changes: 14 additions & 0 deletions torchdata/nodes/loader.py
@@ -10,10 +10,23 @@ def __init__(self, root: BaseNode[T], restart_on_stop_iteration: bool = True):
self.restart_on_stop_iteration = restart_on_stop_iteration
self._next_iter_state_dict: Optional[Dict[str, Any]] = None
self._it: Optional[LoaderIterator[T]] = None
+# Tracks whether an iterator was created solely for getting a state_dict, in which case
+# we don't want to reset the iterator. Consider these two cases, which should behave the same
+# it = iter(loader)
+# sd = loader.state_dict() # No extra __iter__ call as _it already exists
+# for _ in it: ...
+# --------
+# sd = loader.state_dict() # Calls __iter__ since _it is None
+# it = iter(loader) # We don't want to reset the iterator here again
+# for _ in it: ...
+self._iter_for_state_dict: bool = False

def __iter__(self):
if self._it is None:
self._it = LoaderIterator(self)
+elif self._iter_for_state_dict:
+self._iter_for_state_dict = False
+return self._it # This was already pre-called to get a state dict

if self._next_iter_state_dict is not None:
self._it.reset(initial_state=self._next_iter_state_dict)
@@ -31,6 +44,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
def state_dict(self) -> Dict[str, Any]:
if self._it is None:
iter(self)
+self._iter_for_state_dict = True
return self._it.state_dict() # type:ignore[union-attr]


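
Spelled out, the two call orders the new comment describes look like this (a sketch pairing Loader with the StatefulRangeNode test helper from test/nodes/utils.py above):

from torchdata.nodes.loader import Loader

# Case 1: the iterator exists before state_dict() is requested.
loader = Loader(StatefulRangeNode(n=10))
it = iter(loader)
sd = loader.state_dict()  # reuses the live iterator; nothing is reset

# Case 2: state_dict() is requested first. Loader creates an iterator just to
# read its state and flags it with _iter_for_state_dict, so the following
# iter(loader) hands back that same iterator instead of resetting it.
loader2 = Loader(StatefulRangeNode(n=10))
sd2 = loader2.state_dict()
it2 = iter(loader2)  # no second reset; both cases now behave the same
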
29 changes: 20 additions & 9 deletions torchdata/nodes/map.py
@@ -6,19 +6,22 @@

import queue
import threading
+import time
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union

import torch.multiprocessing as mp
from torchdata.nodes.base_node import BaseNode, T
from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper
-from torchdata.nodes.snapshot_store import DequeSnapshotStore, SnapshotStore
+from torchdata.nodes.snapshot_store import QueueSnapshotStore, SnapshotStore

from ._apply_udf import _apply_udf

from ._populate_queue import _populate_queue

from .constants import QUEUE_TIMEOUT

+ACK_TIMEOUT = 300 # Timeout after 5 minutes


# We define this protocol for type checking
class _MultiprocessContext(Protocol):
@@ -147,7 +150,7 @@ def __init__(
else:
self._snapshot = None
self.source.reset()
-self._snapshot_store = DequeSnapshotStore()
+self._snapshot_store = QueueSnapshotStore()

self._read_thread = threading.Thread(
target=_populate_queue,
@@ -192,6 +195,9 @@ def __init__(
if self.in_order:
self._sort_thread.start()

+time.sleep(0.01)
+self._snapshot = self._snapshot_store.get_initial_snapshot(thread=self._read_thread, timeout=ACK_TIMEOUT)

for i in range(fast_forward):
try:
next(self)
@@ -250,13 +256,14 @@ def __del__(self):
def _shutdown(self):
self._stop.set()
self._mp_stop.set()
-if self._read_thread.is_alive():
+if hasattr(self, "_read_thread") and self._read_thread.is_alive():
self._read_thread.join(timeout=QUEUE_TIMEOUT * 5)
-if self._sort_thread.is_alive():
+if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
-for t in self._workers:
-if t.is_alive():
-t.join(timeout=QUEUE_TIMEOUT * 5)
+if hasattr(self, "_workers"):
+for t in self._workers:
+if t.is_alive():
+t.join(timeout=QUEUE_TIMEOUT * 5)


class ParallelMapper(BaseNode[T]):
@@ -403,7 +410,7 @@ def __init__(
else:
self._snapshot = None
self.source.reset()
-self._snapshot_store = DequeSnapshotStore()
+self._snapshot_store = QueueSnapshotStore()
self._thread = threading.Thread(
target=self.worker,
args=(
@@ -417,6 +424,10 @@ def __init__(
daemon=True,
)
self._thread.start()

+# Try and get initial snapshot
+self._snapshot = self._snapshot_store.get_initial_snapshot(thread=self._thread, timeout=ACK_TIMEOUT)

for i in range(self._fast_forward):
try:
next(self)
@@ -471,5 +482,5 @@ def __del__(self):

def _shutdown(self):
self._stop_event.set()
-if self._thread.is_alive():
+if hasattr(self, "_thread") and self._thread.is_alive():
self._thread.join(timeout=QUEUE_TIMEOUT * 5)
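
Taken together, the map.py changes mean construction now blocks until the worker reports its initial snapshot (or fails fast on a startup error), so state_dict() is meaningful immediately, and the hasattr guards let _shutdown() tolerate a partially constructed object. A usage sketch mirroring the test wiring above; the keyword names here are assumptions based on the test calls, not a verified signature:

from torchdata.nodes.batch import Batcher
from torchdata.nodes.map import ParallelMapper

node = StatefulRangeNode(n=80)             # the test helper defined above
node = Batcher(node, batch_size=6, drop_last=False)
node = ParallelMapper(
    node,
    map_fn=lambda batch: batch,            # identity udf, for illustration
    num_workers=2,
    in_order=True,
    method="thread",
    snapshot_frequency=1,
)
node.reset()
_ = next(node)
sd = node.state_dict()  # valid right away: __init__ waited on the initial snapshot
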
