From 72c4d128505244057eaf4cfdf491598a7fa5e270 Mon Sep 17 00:00:00 2001
From: Andrew Ho
Date: Wed, 4 Dec 2024 19:02:40 -0500
Subject: [PATCH] Fix test flakiness and test for it (#1380)

* try to catch flakiness

* raise in dead thread instead of exiting
---
 test/nodes/test_snapshot_store.py | 51 ++++++++++++++++++-------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/test/nodes/test_snapshot_store.py b/test/nodes/test_snapshot_store.py
index 41f152575..892793933 100644
--- a/test/nodes/test_snapshot_store.py
+++ b/test/nodes/test_snapshot_store.py
@@ -13,7 +13,7 @@
 from torchdata.nodes.snapshot_store import QueueSnapshotStore
 
 
-class TestDequeSnapshotStore(TestCase):
+class TestQueueSnapshotStore(TestCase):
     def test_snapshot_store(self) -> None:
         for _ in range(100):
             store = QueueSnapshotStore()
@@ -55,38 +55,45 @@ def test_snapshot_store(self) -> None:
             self.assertEqual(len(store._q.queue), 0)
 
     def test_init_error(self) -> None:
-        store = QueueSnapshotStore()
-        sleep_time = 0.1
-        thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time))
-        thread.start()
-        with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"):
-            store.get_initial_snapshot(thread, sleep_time)
-        thread.join()
+        for _ in range(10):
+            store = QueueSnapshotStore()
+            sleep_time = 0.1
+            thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time))
+            thread.start()
+            with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"):
+                store.get_initial_snapshot(thread, sleep_time)
+            thread.join()
 
     def test_timeout_error(self) -> None:
-        store = QueueSnapshotStore()
-        sleep_time = 0.1
-        thread = threading.Thread(target=time.sleep, args=(sleep_time,))
-        thread.start()
-        with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"):
-            store.get_initial_snapshot(thread, sleep_time * 0.1)
-        thread.join()
+        for _ in range(10):
+            store = QueueSnapshotStore()
+            sleep_time = 0.1
+            thread = threading.Thread(target=_worker_raises_after, args=(sleep_time,))
+            thread.start()
+            with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"):
+                store.get_initial_snapshot(thread, sleep_time * 0.1)
+            thread.join()
 
     def test_thread_dead_error(self) -> None:
         # Test when thread is alive for longer than QUEUE_TIMEOUT but dies afterwards
-        store = QueueSnapshotStore()
-        thread = threading.Thread(target=time.sleep, args=(QUEUE_TIMEOUT * 3.0,))
-        thread.start()
-        with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
-            store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 4.0)
-        thread.join()
+        for _ in range(10):  # Should be reliable
+            store = QueueSnapshotStore()
+            thread = threading.Thread(target=_worker_raises_after, args=(QUEUE_TIMEOUT * 3.0,))
+            thread.start()
+            with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
+                store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 5.0)
+            thread.join()
 
 
 def _worker_init_error(store, sleep_time):
-    # time.sleep(0.1 * sleep_time)
     try:
         raise RuntimeError("Test Startup Exception")
     except Exception as e:
         e = StartupExceptionWrapper(where="_worker_init_error")
         store.append_initial_snapshot(e)
     time.sleep(sleep_time)
+
+
+def _worker_raises_after(sleep_time):
+    time.sleep(sleep_time)
+    raise RuntimeError(f"Thread dying {sleep_time=}")