|
| 1 | +import json |
1 | 2 | import math
|
2 | 3 | import os
|
3 | 4 | import subprocess
|
| 5 | +from pathlib import Path |
4 | 6 |
|
5 |
| -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple |
| 7 | +from typing import Callable, Dict, List, NamedTuple, Optional, Set, Tuple |
| 8 | +from warnings import warn |
| 9 | + |
| 10 | +from tools.shared.logging_utils import duration_to_str, pluralize |
6 | 11 |
|
7 | 12 | from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
|
8 | 13 |
|
@@ -37,7 +42,7 @@ class ShardedTest(NamedTuple):
|
37 | 42 | name: str
|
38 | 43 | shard: int
|
39 | 44 | num_shards: int
|
40 |
| - time: Optional[float] |
| 45 | + time: Optional[float] # In seconds |
41 | 46 |
|
42 | 47 | def __str__(self) -> str:
|
43 | 48 | return f"{self.name} {self.shard}/{self.num_shards}"
|
@@ -133,50 +138,122 @@ def _query_changed_test_files() -> List[str]:
|
133 | 138 | return lines
|
134 | 139 |
|
135 | 140 |
|
| 141 | +def _get_previously_failing_tests() -> Set[str]: |
| 142 | + PYTEST_FAILED_TESTS_CACHE_FILE_PATH = Path(".pytest_cache/v/cache/lastfailed") |
| 143 | + |
| 144 | + if not PYTEST_FAILED_TESTS_CACHE_FILE_PATH.exists(): |
| 145 | + warn( |
| 146 | + f"No pytorch cache found at {PYTEST_FAILED_TESTS_CACHE_FILE_PATH.absolute()}" |
| 147 | + ) |
| 148 | + return set() |
| 149 | + |
| 150 | + with open(PYTEST_FAILED_TESTS_CACHE_FILE_PATH, "r") as f: |
| 151 | + last_failed_tests = json.load(f) |
| 152 | + |
| 153 | + prioritized_tests = _parse_prev_failing_test_files(last_failed_tests) |
| 154 | + return _python_test_file_to_test_name(prioritized_tests) |
| 155 | + |
| 156 | + |
| 157 | +def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]: |
| 158 | + prioritized_tests = set() |
| 159 | + |
| 160 | + # The keys are formatted as "test_file.py::test_class::test_method[params]" |
| 161 | + # We just need the test_file part |
| 162 | + for test in last_failed_tests: |
| 163 | + parts = test.split("::") |
| 164 | + if len(parts) > 1: |
| 165 | + test_file = parts[0] |
| 166 | + prioritized_tests.add(test_file) |
| 167 | + |
| 168 | + return prioritized_tests |
| 169 | + |
| 170 | + |
| 171 | +def _get_modified_tests() -> Set[str]: |
| 172 | + try: |
| 173 | + changed_files = _query_changed_test_files() |
| 174 | + except Exception as e: |
| 175 | + warn(f"Can't query changed test files due to {e}") |
| 176 | + # If unable to get changed files from git, quit without doing any sorting |
| 177 | + return set() |
| 178 | + |
| 179 | + return _python_test_file_to_test_name(set(changed_files)) |
| 180 | + |
| 181 | + |
| 182 | +def _python_test_file_to_test_name(tests: Set[str]) -> Set[str]: |
| 183 | + prefix = f"test{os.path.sep}" |
| 184 | + valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")} |
| 185 | + valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests} |
| 186 | + |
| 187 | + return valid_tests |
| 188 | + |
| 189 | + |
136 | 190 | def get_reordered_tests(
|
137 | 191 | tests: List[ShardedTest],
|
138 | 192 | ) -> Tuple[List[ShardedTest], List[ShardedTest]]:
|
139 | 193 | """
|
140 | 194 | Get the reordered test filename list based on github PR history or git changed file.
|
141 | 195 | We prioritize running test files that were changed.
|
142 | 196 | """
|
143 |
| - prioritized_tests: List[str] = [] |
144 |
| - if len(prioritized_tests) == 0: |
145 |
| - try: |
146 |
| - changed_files = _query_changed_test_files() |
147 |
| - except Exception: |
148 |
| - # If unable to get changed files from git, quit without doing any sorting |
149 |
| - return ([], tests) |
150 |
| - |
151 |
| - prefix = f"test{os.path.sep}" |
152 |
| - prioritized_tests = [ |
153 |
| - f for f in changed_files if f.startswith(prefix) and f.endswith(".py") |
154 |
| - ] |
155 |
| - prioritized_tests = [f[len(prefix) :] for f in prioritized_tests] |
156 |
| - prioritized_tests = [f[: -len(".py")] for f in prioritized_tests] |
157 |
| - print("Prioritized test from test file changes.") |
| 197 | + |
| 198 | + def print_tests(tests: Set[str], test_group_description: str) -> None: |
| 199 | + if not tests: |
| 200 | + return |
| 201 | + |
| 202 | + print(f"{test_group_description}:") |
| 203 | + for test in tests: |
| 204 | + print(f" {test}") |
| 205 | + |
| 206 | + prioritized_tests: Set[str] = set() |
| 207 | + |
| 208 | + pri_test = _get_previously_failing_tests() |
| 209 | + print_tests( |
| 210 | + pri_test, "If run, these tests will prioritized because they previously failed" |
| 211 | + ) |
| 212 | + prioritized_tests |= pri_test |
| 213 | + |
| 214 | + pri_test |= _get_modified_tests() |
| 215 | + print_tests( |
| 216 | + pri_test, "If run, these tests will be prioritized because they were modified" |
| 217 | + ) |
| 218 | + prioritized_tests |= pri_test |
158 | 219 |
|
159 | 220 | bring_to_front = []
|
160 | 221 | the_rest = []
|
161 | 222 |
|
| 223 | + test_time_for_regular_tests_so_far = 0.0 |
| 224 | + # how much sooner did we run prioritized tests compared to a naive ordering |
| 225 | + time_savings_sec = 0.0 |
| 226 | + |
162 | 227 | for test in tests:
|
163 | 228 | if test.name in prioritized_tests:
|
164 | 229 | bring_to_front.append(test)
|
| 230 | + # Calculate approx time saved by reordering |
| 231 | + time_savings_sec = test_time_for_regular_tests_so_far |
165 | 232 | else:
|
166 | 233 | the_rest.append(test)
|
167 |
| - if len(tests) == len(bring_to_front) + len(the_rest): |
168 |
| - print( |
169 |
| - f"reordering tests for PR:\n" |
170 |
| - f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n" |
171 |
| - ) |
172 |
| - return (bring_to_front, the_rest) |
173 |
| - else: |
| 234 | + test_time_for_regular_tests_so_far += test.get_time() |
| 235 | + |
| 236 | + if len(tests) != len(bring_to_front) + len(the_rest): |
174 | 237 | print(
|
175 | 238 | f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
|
176 | 239 | f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n"
|
177 | 240 | )
|
178 | 241 | return ([], tests)
|
179 | 242 |
|
| 243 | + # TODO: Would be great to upload these stats to RDS/Rockset! |
| 244 | + test_cnt_str = pluralize(len(tests), "test") |
| 245 | + print(f"Reordering tests: Prioritizing {len(bring_to_front)} of {test_cnt_str}") |
| 246 | + print( |
| 247 | + f"Prioritized tests estimated to run up to {duration_to_str(time_savings_sec)} sooner than they would've otherwise" |
| 248 | + ) |
| 249 | + |
| 250 | + prioritized_test_names = [t.name for t in bring_to_front] |
| 251 | + print(f"Prioritized: {prioritized_test_names}") |
| 252 | + remaining_test_names = [t.name for t in the_rest] |
| 253 | + print(f"The Rest: {remaining_test_names}") |
| 254 | + |
| 255 | + return (bring_to_front, the_rest) |
| 256 | + |
180 | 257 |
|
181 | 258 | def get_test_case_configs(dirpath: str) -> None:
|
182 | 259 | get_slow_tests(dirpath=dirpath)
|
|
0 commit comments