Skip to content

Commit

Permalink
Fix test flakiness and test for it (#1380)
Browse files Browse the repository at this point in the history
* try to catch flakiness

* raise in dead thread instead of exiting
  • Loading branch information
andrewkho authored Dec 5, 2024
1 parent 60380f1 commit 72c4d12
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions test/nodes/test_snapshot_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchdata.nodes.snapshot_store import QueueSnapshotStore


class TestDequeSnapshotStore(TestCase):
class TestQueueSnapshotStore(TestCase):
def test_snapshot_store(self) -> None:
for _ in range(100):
store = QueueSnapshotStore()
Expand Down Expand Up @@ -55,38 +55,45 @@ def test_snapshot_store(self) -> None:
self.assertEqual(len(store._q.queue), 0)

def test_init_error(self) -> None:
store = QueueSnapshotStore()
sleep_time = 0.1
thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time))
thread.start()
with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"):
store.get_initial_snapshot(thread, sleep_time)
thread.join()
for _ in range(10):
store = QueueSnapshotStore()
sleep_time = 0.1
thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time))
thread.start()
with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"):
store.get_initial_snapshot(thread, sleep_time)
thread.join()

def test_timeout_error(self) -> None:
store = QueueSnapshotStore()
sleep_time = 0.1
thread = threading.Thread(target=time.sleep, args=(sleep_time,))
thread.start()
with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"):
store.get_initial_snapshot(thread, sleep_time * 0.1)
thread.join()
for _ in range(10):
store = QueueSnapshotStore()
sleep_time = 0.1
thread = threading.Thread(target=_worker_raises_after, args=(sleep_time,))
thread.start()
with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"):
store.get_initial_snapshot(thread, sleep_time * 0.1)
thread.join()

def test_thread_dead_error(self) -> None:
# Test when thread is alive for longer than QUEUE_TIMEOUT but dies afterwards
store = QueueSnapshotStore()
thread = threading.Thread(target=time.sleep, args=(QUEUE_TIMEOUT * 3.0,))
thread.start()
with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 4.0)
thread.join()
for _ in range(10): # Should be reliable
store = QueueSnapshotStore()
thread = threading.Thread(target=_worker_raises_after, args=(QUEUE_TIMEOUT * 3.0,))
thread.start()
with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 5.0)
thread.join()


def _worker_init_error(store, sleep_time):
    """Thread target that reports a startup failure to ``store``, then lingers.

    A ``RuntimeError("Test Startup Exception")`` is raised and caught locally;
    a ``StartupExceptionWrapper`` is appended to ``store`` as the initial
    snapshot, after which the function sleeps for ``sleep_time`` seconds
    before returning.
    """
    try:
        raise RuntimeError("Test Startup Exception")
    except Exception:
        # NOTE(review): the wrapper is built inside the handler, presumably so
        # it can pick up the in-flight exception -- confirm against
        # StartupExceptionWrapper.
        wrapped = StartupExceptionWrapper(where="_worker_init_error")
        store.append_initial_snapshot(wrapped)
    # Keep the thread alive for a while after the failure has been published.
    time.sleep(sleep_time)


def _worker_raises_after(sleep_time):
time.sleep(sleep_time)
raise RuntimeError(f"Thread dying {sleep_time=}")

0 comments on commit 72c4d12

Please sign in to comment.