Skip to content

Commit b147401

Browse files
ZainRizvipytorchmergebot
authored andcommitted
Test Reordering: Run previously failing tests first (pytorch#101123)
Makes the CI prioritize running any test files that had a failing test in a previous iteration of the given PR. A follow up to pytorch#100522 which makes the `.pytest_cache` available to use here A concrete example: 1. Person A pushes a new commit and creates a PR. 2. 2 hours later, test_im_now_broken.py fails 3. Person A attempts to fix the test, but the test is actually still broken 4. The CI, seeing that test_im_now_broken.py had failed on a previous run, will now prioritize running that test first. Instead of waiting another 2 hours to get a signal, Person A only needs to wait ~15 minutes (which is how long it takes for tests to start running) # Testing I modified a file to make the tests invoking it fail and triggered CI twice with this failure. First run: https://github.com/pytorch/pytorch/actions/runs/4963943209/jobs/8883800811 Test step took 1h 9m to run Second run: https://github.com/pytorch/pytorch/actions/runs/4965016776/jobs/8885657992 Test step failed within 2m 27s Pull Request resolved: pytorch#101123 Approved by: https://github.com/malfet, https://github.com/huydhn
1 parent b5ed606 commit b147401

File tree

3 files changed

+207
-26
lines changed

3 files changed

+207
-26
lines changed

tools/shared/logging_utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def pluralize(count: int, singular_word: str, plural_word: str = "") -> str:
2+
if count == 1:
3+
return f"{count} {singular_word}"
4+
5+
if not plural_word:
6+
plural_word = f"{singular_word}s"
7+
8+
return f"{count} {plural_word}"
9+
10+
11+
def duration_to_str(seconds: float) -> str:
12+
if seconds < 0.00001:
13+
return "0s"
14+
elif seconds < 60:
15+
return f"{seconds:.1f}s"
16+
elif seconds < 3600:
17+
return f"{seconds / 60:.1f}m"
18+
else:
19+
return f"{seconds / 3600:.1f}h"

tools/test/test_test_selections.py

+87-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1+
import io
2+
import json
13
import pathlib
24
import random
35
import sys
46
import unittest
57
from collections import defaultdict
6-
from typing import Dict, List, Tuple
8+
from typing import Any, Dict, List, Set, Tuple
9+
from unittest import mock
710

811
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
912
try:
1013
# using tools/ to optimize test run.
1114
sys.path.append(str(REPO_ROOT))
12-
from tools.testing.test_selections import calculate_shards, ShardedTest, THRESHOLD
15+
from tools.testing.test_selections import (
16+
_get_previously_failing_tests,
17+
calculate_shards,
18+
get_reordered_tests,
19+
ShardedTest,
20+
THRESHOLD,
21+
)
1322
except ModuleNotFoundError:
1423
print("Can't import required modules, exiting")
1524
exit(1)
@@ -328,5 +337,81 @@ def test_calculate_2_shards_against_optimal_shards(self) -> None:
328337
self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
329338

330339

340+
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
341+
file_object = io.StringIO()
342+
json.dump(contents, file_object)
343+
file_object.seek(0)
344+
return file_object
345+
346+
347+
class TestParsePrevTests(unittest.TestCase):
348+
@mock.patch("pathlib.Path.exists", return_value=False)
349+
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
350+
expected_failing_test_files: Set[str] = set()
351+
352+
found_tests = _get_previously_failing_tests()
353+
354+
self.assertSetEqual(expected_failing_test_files, found_tests)
355+
356+
@mock.patch("pathlib.Path.exists", return_value=True)
357+
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
358+
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
359+
expected_failing_test_files: Set[str] = set()
360+
361+
found_tests = _get_previously_failing_tests()
362+
363+
self.assertSetEqual(expected_failing_test_files, found_tests)
364+
mock_open.assert_called()
365+
366+
lastfailed_with_multiple_tests_per_file = {
367+
"test/test_car.py::TestCar::test_num[17]": True,
368+
"test/test_car.py::TestBar::test_num[25]": True,
369+
"test/test_far.py::TestFar::test_fun_copy[17]": True,
370+
"test/test_bar.py::TestBar::test_fun_copy[25]": True,
371+
}
372+
373+
@mock.patch("pathlib.Path.exists", return_value=True)
374+
@mock.patch(
375+
"builtins.open",
376+
return_value=mocked_file(lastfailed_with_multiple_tests_per_file),
377+
)
378+
def test_dedupes_failing_test_files(self, mock_exists: Any, mock_open: Any) -> None:
379+
expected_failing_test_files = {"test_car", "test_bar", "test_far"}
380+
found_tests = _get_previously_failing_tests()
381+
382+
self.assertSetEqual(expected_failing_test_files, found_tests)
383+
384+
@mock.patch(
385+
"tools.testing.test_selections._get_previously_failing_tests",
386+
return_value={"test4"},
387+
)
388+
@mock.patch(
389+
"tools.testing.test_selections._get_modified_tests",
390+
return_value={"test2", "test4"},
391+
)
392+
def test_get_reordered_tests(
393+
self, mock_get_prev_failing_tests: Any, mock_get_modified_tests: Any
394+
) -> None:
395+
tests = [
396+
ShardedTest(name="test1", shard=1, num_shards=2, time=600.0),
397+
ShardedTest(name="test2", shard=1, num_shards=2, time=500.0),
398+
ShardedTest(name="test3", shard=1, num_shards=2, time=400.0),
399+
ShardedTest(name="test4", shard=1, num_shards=2, time=300.0),
400+
ShardedTest(name="test5", shard=1, num_shards=2, time=200.0),
401+
]
402+
403+
expected_prioritized_tests = {"test4", "test2"}
404+
expected_remaining_tests = {"test1", "test3", "test5"}
405+
406+
prioritized_tests, remaining_tests = get_reordered_tests(tests)
407+
408+
# Just want to check the names of the tests
409+
prioritized_tests_name = {test.name for test in prioritized_tests}
410+
remaining_tests_name = {test.name for test in remaining_tests}
411+
412+
self.assertSetEqual(expected_prioritized_tests, prioritized_tests_name)
413+
self.assertSetEqual(expected_remaining_tests, remaining_tests_name)
414+
415+
331416
if __name__ == "__main__":
332417
unittest.main()

tools/testing/test_selections.py

+101-24
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import json
12
import math
23
import os
34
import subprocess
5+
from pathlib import Path
46

5-
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
7+
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple
8+
from warnings import warn
9+
10+
from tools.shared.logging_utils import duration_to_str, pluralize
611

712
from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
813

@@ -37,7 +42,7 @@ class ShardedTest(NamedTuple):
3742
name: str
3843
shard: int
3944
num_shards: int
40-
time: Optional[float]
45+
time: Optional[float] # In seconds
4146

4247
def __str__(self) -> str:
4348
return f"{self.name} {self.shard}/{self.num_shards}"
@@ -133,50 +138,122 @@ def _query_changed_test_files() -> List[str]:
133138
return lines
134139

135140

141+
def _get_previously_failing_tests() -> Set[str]:
142+
PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed")
143+
144+
if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists():
145+
warn(
146+
f"No pytorch cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}"
147+
)
148+
return set()
149+
150+
with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f:
151+
last_failed_tests = json.load(f)
152+
153+
prioritized_tests = _parse_prev_failing_test_files(last_failed_tests)
154+
return _python_test_file_to_test_name(prioritized_tests)
155+
156+
157+
def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]:
158+
prioritized_tests = set()
159+
160+
# The keys are formatted as "test_file.py::test_class::test_method[params]"
161+
# We just need the test_file part
162+
for test in last_failed_tests:
163+
parts = test.split("::")
164+
if len(parts) > 1:
165+
test_file = parts[0]
166+
prioritized_tests.add(test_file)
167+
168+
return prioritized_tests
169+
170+
171+
def _get_modified_tests() -> Set[str]:
172+
try:
173+
changed_files = _query_changed_test_files()
174+
except Exception as e:
175+
warn(f"Can't query changed test files due to {e}")
176+
# If unable to get changed files from git, quit without doing any sorting
177+
return set()
178+
179+
return _python_test_file_to_test_name(set(changed_files))
180+
181+
182+
def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]:
183+
prefix = f"test{os.path.sep}"
184+
valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")}
185+
valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests}
186+
187+
return valid_tests
188+
189+
136190
def get_reordered_tests(
137191
tests: List[ShardedTest],
138192
) -> Tuple[List[ShardedTest], List[ShardedTest]]:
139193
"""
140194
Get the reordered test filename list based on github PR history or git changed file.
141195
We prioritize running test files that were changed.
142196
"""
143-
prioritized_tests: List[str] = []
144-
if len(prioritized_tests) == 0:
145-
try:
146-
changed_files = _query_changed_test_files()
147-
except Exception:
148-
# If unable to get changed files from git, quit without doing any sorting
149-
return ([], tests)
150-
151-
prefix = f"test{os.path.sep}"
152-
prioritized_tests = [
153-
f for f in changed_files if f.startswith(prefix) and f.endswith(".py")
154-
]
155-
prioritized_tests = [f[len(prefix) :] for f in prioritized_tests]
156-
prioritized_tests = [f[: -len(".py")] for f in prioritized_tests]
157-
print("Prioritized test from test file changes.")
197+
198+
def print_tests(tests: Set[str], test_group_description: str) -> None:
199+
if not tests:
200+
return
201+
202+
print(f"{test_group_description}:")
203+
for test in tests:
204+
print(f" {test}")
205+
206+
prioritized_tests: Set[str] = set()
207+
208+
pri_test = _get_previously_failing_tests()
209+
print_tests(
210+
pri_test, "If run, these tests will prioritized because they previously failed"
211+
)
212+
prioritized_tests |= pri_test
213+
214+
pri_test |= _get_modified_tests()
215+
print_tests(
216+
pri_test, "If run, these tests will be prioritized because they were modified"
217+
)
218+
prioritized_tests |= pri_test
158219

159220
bring_to_front = []
160221
the_rest = []
161222

223+
test_time_for_regular_tests_so_far = 0.0
224+
# how much sooner did we run prioritized tests compared to a naive ordering
225+
time_savings_sec = 0.0
226+
162227
for test in tests:
163228
if test.name in prioritized_tests:
164229
bring_to_front.append(test)
230+
# Calculate approx time saved by reordering
231+
time_savings_sec = test_time_for_regular_tests_so_far
165232
else:
166233
the_rest.append(test)
167-
if len(tests) == len(bring_to_front) + len(the_rest):
168-
print(
169-
f"reordering tests for PR:\n"
170-
f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
171-
)
172-
return (bring_to_front, the_rest)
173-
else:
234+
test_time_for_regular_tests_so_far += test.get_time()
235+
236+
if len(tests) != len(bring_to_front) + len(the_rest):
174237
print(
175238
f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
176239
f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
177240
)
178241
return ([], tests)
179242

243+
# TODO: Would be great to upload these stats to RDS/Rockset!
244+
test_cnt_str = pluralize(len(tests), "test")
245+
print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}")
246+
print(
247+
f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise"
248+
)
249+
250+
prioritized_test_names = [t.name for t in bring_to_front]
251+
print(f"Prioritized: {prioritized_test_names}")
252+
remaining_test_names = [t.name for t in the_rest]
253+
print(f"The Rest: {remaining_test_names}")
254+
255+
return (bring_to_front, the_rest)
256+
180257

181258
def get_test_case_configs(dirpath: str) -> None:
182259
get_slow_tests(dirpath=dirpath)

0 commit comments

Comments
 (0)