Skip to content

Commit

Permalink
fix(anta): Refactor ANTA runner (#656)
Browse files Browse the repository at this point in the history
* fix(anta): Refactor ANTA runner
  • Loading branch information
carl-baillargeon authored Apr 25, 2024
1 parent 15af88d commit 79fa2c9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
64 changes: 41 additions & 23 deletions anta/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,21 @@
import os
import resource
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from anta import GITHUB_SUGGESTION
from anta.logger import anta_log_exception, exc_to_str
from anta.models import AntaTest
from anta.tools import Catchtime

if TYPE_CHECKING:
from collections.abc import Coroutine

from anta.catalog import AntaCatalog, AntaTestDefinition
from anta.device import AntaDevice
from anta.inventory import AntaInventory
from anta.result_manager import ResultManager
from anta.result_manager.models import TestResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +111,7 @@ async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devic
return selected_inventory


async def prepare_tests(
def prepare_tests(
inventory: AntaInventory, catalog: AntaCatalog, tests: set[str] | None, tags: set[str] | None
) -> defaultdict[AntaDevice, set[AntaTestDefinition]] | None:
"""Prepare the tests to run.
Expand Down Expand Up @@ -154,7 +157,37 @@ async def prepare_tests(
return device_to_tests


async def main( # noqa: PLR0913,C901
def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]]) -> list[Coroutine[Any, Any, TestResult]]:
"""Get the coroutines for the ANTA run.
Args:
----
selected_tests: A mapping of devices to the tests to run. The selected tests are generated by the `prepare_tests` function.
Returns
-------
The list of coroutines to run.
"""
coros = []
for device, test_definitions in selected_tests.items():
for test in test_definitions:
try:
test_instance = test.test(device=device, inputs=test.inputs)
coros.append(test_instance.test())
except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught
# An AntaTest instance is potentially user-defined code.
# We need to catch everything and exit gracefully with an error message.
message = "\n".join(
[
f"There is an error when creating test {test.test.__module__}.{test.test.__name__}.",
f"If this is not a custom test implementation: {GITHUB_SUGGESTION}",
],
)
anta_log_exception(e, message, logger)
return coros


async def main( # noqa: PLR0913
manager: ResultManager,
inventory: AntaInventory,
catalog: AntaCatalog,
Expand Down Expand Up @@ -196,7 +229,7 @@ async def main( # noqa: PLR0913,C901
return

with Catchtime(logger=logger, message="Preparing the tests"):
selected_tests = await prepare_tests(selected_inventory, catalog, tests, tags)
selected_tests = prepare_tests(selected_inventory, catalog, tests, tags)
if selected_tests is None:
return

Expand All @@ -217,34 +250,19 @@ async def main( # noqa: PLR0913,C901
"Please consult the ANTA FAQ."
)

coros = []
for device, test_definitions in selected_tests.items():
for test in test_definitions:
try:
test_instance = test.test(device=device, inputs=test.inputs)
coros.append(test_instance.test())
except Exception as e: # noqa: PERF203, pylint: disable=broad-exception-caught
# An AntaTest instance is potentially user-defined code.
# We need to catch everything and exit gracefully with an error message.
message = "\n".join(
[
f"There is an error when creating test {test.test.__module__}.{test.test.__name__}.",
f"If this is not a custom test implementation: {GITHUB_SUGGESTION}",
],
)
anta_log_exception(e, message, logger)
coroutines = get_coroutines(selected_tests)

if dry_run:
logger.info("Dry-run mode, exiting before running the tests.")
for coro in coros:
for coro in coroutines:
coro.close()
return

if AntaTest.progress is not None:
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coros))
AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines))

with Catchtime(logger=logger, message="Running ANTA tests"):
test_results = await asyncio.gather(*coros)
test_results = await asyncio.gather(*coroutines)
for r in test_results:
manager.add(r)

Expand Down
4 changes: 2 additions & 2 deletions tests/units/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def test_prepare_tests(
caplog.set_level(logging.INFO)

catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml"))
selected_tests = await prepare_tests(inventory=test_inventory, catalog=catalog, tags=tags, tests=None)
selected_tests = prepare_tests(inventory=test_inventory, catalog=catalog, tags=tags, tests=None)

if selected_tests is None:
assert expected_tests_count == 0
Expand All @@ -180,7 +180,7 @@ async def test_prepare_tests_with_specific_tests(caplog: pytest.LogCaptureFixtur
caplog.set_level(logging.INFO)

catalog: AntaCatalog = AntaCatalog.parse(str(DATA_DIR / "test_catalog_with_tags.yml"))
selected_tests = await prepare_tests(inventory=test_inventory, catalog=catalog, tags=None, tests={"VerifyMlagStatus", "VerifyUptime"})
selected_tests = prepare_tests(inventory=test_inventory, catalog=catalog, tags=None, tests={"VerifyMlagStatus", "VerifyUptime"})

assert selected_tests is not None
assert len(selected_tests) == 3
Expand Down

0 comments on commit 79fa2c9

Please sign in to comment.