diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e9430db3..824b6a96b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,6 +74,7 @@ repos: - pytest - pytest-codspeed - respx + - pydantic-settings - repo: https://github.com/codespell-project/codespell rev: v2.4.1 @@ -123,6 +124,7 @@ repos: pass_filenames: false additional_dependencies: - anta[cli] + - pydantic-settings - id: doc-snippets name: Generate doc snippets entry: >- @@ -134,3 +136,4 @@ repos: pass_filenames: false additional_dependencies: - anta[cli] + - pydantic-settings diff --git a/anta/_runner.py b/anta/_runner.py new file mode 100644 index 000000000..e3bb7d375 --- /dev/null +++ b/anta/_runner.py @@ -0,0 +1,473 @@ +# Copyright (c) 2023-2025 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""ANTA runner classes.""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, ConfigDict, PrivateAttr + +from anta import GITHUB_SUGGESTION +from anta.catalog import AntaCatalog, AntaTestDefinition +from anta.cli.console import console +from anta.device import AntaDevice +from anta.inventory import AntaInventory +from anta.logger import anta_log_exception +from anta.models import AntaTest +from anta.result_manager import ResultManager +from anta.settings import AntaRunnerSettings +from anta.tools import Catchtime, limit_concurrency + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Coroutine + + from anta.result_manager.models import TestResult + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class AntaRunnerInventoryStats: + """Store inventory filtering statistics of an ANTA run. + + This class maintains counters for tracking how the device inventory is filtered + during test setup. + + Attributes + ---------- + total : int + Total number of devices in the original inventory before filtering. + filtered_by_tags : int + Number of devices excluded due to tag filtering. + connection_failed : int + Number of devices that failed to establish a connection. + established : int + Number of devices with successfully established connections that are + included in the final test run. + """ + + total: int + filtered_by_tags: int + connection_failed: int + established: int + + +class AntaRunnerFilter(BaseModel): + """Define a filter for an ANTA run. + + The filter determines which devices and tests to include in a run, and how to + filter them with tags. This class is used with the `AntaRunner.run()` method. + + Attributes + ---------- + devices : set[str] | None, optional + Set of device names to run tests on. If `None`, includes all devices in + the inventory. Commonly set via the NRFU CLI `--device/-d` option. + tests : set[str] | None, optional + Set of test names to run. If `None`, runs all available tests in the + catalog. Commonly set via the NRFU CLI `--test/-t` option. + tags : set[str] | None, optional + Set of tags used to filter both devices and tests. A device or test + must match any of the provided tags to be included. Commonly set via + the NRFU CLI `--tags` option. + established_only : bool, default=True + When `True`, only includes devices with established connections in the + test run. 
+ """ + + model_config = ConfigDict(frozen=True, extra="forbid") + devices: set[str] | None = None + tests: set[str] | None = None + tags: set[str] | None = None + established_only: bool = True + + +class AntaRunner(BaseModel): + """Run and manage ANTA test execution. + + This class orchestrates the execution of ANTA tests across network devices. It handles + inventory filtering, test selection, concurrent test execution, and result collection. + + Attributes + ---------- + inventory : AntaInventory + Inventory of network devices to test. + catalog : AntaCatalog + Catalog of available tests to run. + manager : ResultManager | None, optional + Manager for collecting and storing test results. If `None`, a new manager + is returned for each run, otherwise the provided manager is used + and results from subsequent runs are appended to it. + _selected_inventory : AntaInventory | None + Internal state of filtered inventory for current run. + _selected_tests : defaultdict[AntaDevice, set[AntaTestDefinition]] | None + Mapping of devices to their selected tests for current run. + _inventory_stats : AntaRunnerInventoryStats | None + Statistics about inventory filtering for current run. + _total_tests : int + Total number of tests to run in current execution. + _potential_connections : float | None + Total potential concurrent connections needed for current run. + `None` if unknown. + _settings : AntaRunnerSettings + Internal settings loaded from environment variables. See the class definition + in the `anta.settings` module for details. + + Notes + ----- + After initializing an `AntaRunner` instance, tests should only be executed through + the `run()` method. This method manages the complete test lifecycle including setup, + execution, and cleanup. + + All internal methods and state (prefixed with `_`) are managed by the `run()` method + and should not be called directly. The internal state is reset between runs to + ensure clean execution. 
+
+    Examples
+    --------
+    ```python
+    import asyncio
+
+    from anta._runner import AntaRunner, AntaRunnerFilter
+    from anta.catalog import AntaCatalog
+    from anta.inventory import AntaInventory
+
+    inventory = AntaInventory.parse(
+        filename="anta_inventory.yml",
+        username="arista",
+        password="arista",
+    )
+    catalog = AntaCatalog.parse(filename="anta_catalog.yml")
+
+    # Create an ANTA runner
+    runner = AntaRunner(inventory=inventory, catalog=catalog)
+
+    # Run all tests
+    first_run_results = asyncio.run(runner.run())
+
+    # Run with filters
+    second_run_results = asyncio.run(runner.run(filters=AntaRunnerFilter(tags={"leaf"})))
+    ```
+    """
+
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+    inventory: AntaInventory
+    catalog: AntaCatalog
+    manager: ResultManager | None = None
+
+    # Internal attributes set during setup phases before each run
+    _selected_inventory: AntaInventory | None = PrivateAttr(default=None)
+    _selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]] | None = PrivateAttr(default=None)
+    _inventory_stats: AntaRunnerInventoryStats | None = PrivateAttr(default=None)
+    _total_tests: int = PrivateAttr(default=0)
+    _potential_connections: float | None = PrivateAttr(default=None)
+
+    # Internal settings loaded from environment variables
+    _settings: AntaRunnerSettings = PrivateAttr(default_factory=AntaRunnerSettings)
+
+    def reset(self) -> None:
+        """Reset the internal attributes of the ANTA runner."""
+        self._selected_inventory: AntaInventory | None = None
+        self._selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]] | None = None
+        self._inventory_stats: AntaRunnerInventoryStats | None = None
+        self._total_tests: int = 0
+        self._potential_connections: float | None = None
+
+    async def run(self, filters: AntaRunnerFilter | None = None, *, dry_run: bool = False) -> ResultManager:
+        """Run ANTA.
+
+        Parameters
+        ----------
+        filters
+            Filters for the ANTA run. If None, runs all tests on all devices.
+        dry_run
+            Dry-run mode flag. If True, runs all setup steps but does not execute tests.
+        """
+        filters = filters or AntaRunnerFilter()
+
+        # Cleanup the instance before each run
+        self.reset()
+        self.catalog.clear_indexes()
+        manager = ResultManager() if self.manager is None else self.manager
+
+        if not self.catalog.tests:
+            logger.info("The list of tests is empty, exiting")
+            return manager
+
+        with Catchtime(logger=logger, message="Preparing ANTA NRFU Run"):
+            # Set up inventory
+            if not await self._setup_inventory(filters, dry_run=dry_run):
+                return manager
+
+            # Set up tests
+            with Catchtime(logger=logger, message="Preparing Tests"):
+                if not self._setup_tests(filters):
+                    return manager
+
+            # Build the test generator
+            test_generator = self._test_generator(manager if dry_run else None)
+
+            # Log run information
+            self._log_run_information(dry_run=dry_run)
+
+        if dry_run:
+            logger.info("Dry-run mode, exiting before running the tests.")
+            async for test in test_generator:
+                test.close()
+            return manager
+
+        if AntaTest.progress is not None:
+            AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=self._total_tests)
+
+        with Catchtime(logger=logger, message="Running Tests"):
+            async for result in limit_concurrency(test_generator, limit=self._settings.max_concurrency):
+                manager.add(await result)
+
+        self._log_cache_statistics()
+        return manager
+
+    async def _setup_inventory(self, filters: AntaRunnerFilter, *, dry_run: bool = False) -> bool:
+        """Set up the inventory for the ANTA run."""
+        total_devices = len(self.inventory)
+
+        # In dry-run mode, set the selected inventory to the full inventory
+        if dry_run:
+            self._selected_inventory = self.inventory
+            self._inventory_stats = AntaRunnerInventoryStats(total=total_devices, filtered_by_tags=0, connection_failed=0, established=0)
+            return True
+
+        # If the inventory is empty, exit
+        if total_devices == 0:
+            logger.info("The inventory is empty, exiting")
+            return False
+
+        # Filter the inventory based on the provided filters, if any
+        filtered_inventory = self.inventory.get_inventory(tags=filters.tags, devices=filters.devices) if filters.tags or filters.devices else self.inventory
+        filtered_by_tags = total_devices - len(filtered_inventory)
+
+        # Connect to devices
+        with Catchtime(logger=logger, message="Connecting to devices"):
+            await filtered_inventory.connect_inventory()
+
+        # Remove devices that are unreachable
+        self._selected_inventory = filtered_inventory.get_inventory(established_only=filters.established_only)
+        connection_failed = len(filtered_inventory) - len(self._selected_inventory)
+
+        # If there are no devices in the inventory after filtering, exit
+        if not self._selected_inventory.devices:
+            # Build message parts
+            tag_msg = f"matching the tags {filters.tags} " if filters.tags else ""
+            device_msg = f" Selected devices: {filters.devices} " if filters.devices is not None else ""
+            logger.warning("No reachable device %swas found.%s", tag_msg, device_msg)
+            return False
+
+        self._inventory_stats = AntaRunnerInventoryStats(
+            total=total_devices, filtered_by_tags=filtered_by_tags, connection_failed=connection_failed, established=len(self._selected_inventory)
+        )
+        return True
+
+    def _setup_tests(self, filters: AntaRunnerFilter) -> bool:
+        """Set up tests for the ANTA run."""
+        if self._selected_inventory is None:
+            msg = "The selected inventory is not available. ANTA must be executed through AntaRunner.run()"
+            raise RuntimeError(msg)
+
+        # Build indexes for the catalog. If `filters.tests` is set, filter the indexes based on these tests
+        self.catalog.build_indexes(filtered_tests=filters.tests)
+
+        # Using a set to avoid inserting duplicate tests
+        device_to_tests: defaultdict[AntaDevice, set[AntaTestDefinition]] = defaultdict(set)
+        total_tests = 0
+        total_connections = 0
+        all_have_connections = True
+
+        # Create the device to tests mapping from the tags and calculate connection stats
+        for device in self._selected_inventory.devices:
+            if filters.tags:
+                # If there are CLI tags, execute tests with matching tags for this device
+                if not (matching_tags := filters.tags.intersection(device.tags)):
+                    # The device does not have any selected tag, skipping
+                    continue
+                device_to_tests[device].update(self.catalog.get_tests_by_tags(matching_tags))
+            else:
+                # If there are no CLI tags, execute all tests that do not have any tags
+                device_to_tests[device].update(self.catalog.tag_to_tests[None])
+
+                # Then add the tests with matching tags from device tags
+                device_to_tests[device].update(self.catalog.get_tests_by_tags(device.tags))
+
+            total_tests += len(device_to_tests[device])
+
+            # Check device connections
+            if not hasattr(device, "max_connections") or device.max_connections is None:
+                all_have_connections = False
+            else:
+                total_connections += device.max_connections
+
+        if total_tests == 0:
+            tag_msg = f" matching the tags {filters.tags} " if filters.tags else " "
+            logger.warning(
+                "There are no tests%sto run in the current test catalog and device inventory, please verify your inputs.",
+                tag_msg,
+            )
+            return False
+
+        self._selected_tests = device_to_tests
+        self._total_tests = total_tests
+        self._potential_connections = None if not all_have_connections else total_connections
+        return True
+
+    async def _test_generator(self, manager: ResultManager | None = None) -> AsyncGenerator[Coroutine[Any, Any, TestResult], None]:
+        """Generate test coroutines for the ANTA run."""
+        if self._selected_tests is None:
+            msg = "The selected tests are not available. ANTA must be executed through AntaRunner.run()"
+            raise RuntimeError(msg)
+
+        logger.debug("ANTA run scheduling strategy: %s", self._settings.scheduling_strategy)
+
+        device_to_tests: dict[AntaDevice, list[AntaTestDefinition]] = {
+            device: sorted(tests, key=lambda td: td.test.__name__) for device, tests in self._selected_tests.items()
+        }
+
+        if self._settings.scheduling_strategy == "device-by-device":
+            async for coro in self._generate_device_by_device(device_to_tests, manager):
+                yield coro
+        elif self._settings.scheduling_strategy == "device-by-count":
+            async for coro in self._generate_device_by_count(device_to_tests, self._settings.scheduling_tests_per_device, manager):
+                yield coro
+        else:
+            # Default to round-robin
+            async for coro in self._generate_round_robin(device_to_tests, manager):
+                yield coro
+
+    async def _generate_round_robin(
+        self, device_to_tests: dict[AntaDevice, list[AntaTestDefinition]], manager: ResultManager | None = None
+    ) -> AsyncGenerator[Coroutine[Any, Any, TestResult], None]:
+        """Yield one test per device in each round."""
+        while any(device_to_tests.values()):
+            for device, tests in device_to_tests.items():
+                if tests:
+                    test_def = tests.pop(0)
+                    coro = self._create_test_coroutine(test_def, device, manager)
+                    if coro is not None:
+                        yield coro
+
+    async def _generate_device_by_device(
+        self, device_to_tests: dict[AntaDevice, list[AntaTestDefinition]], manager: ResultManager | None = None
+    ) -> AsyncGenerator[Coroutine[Any, Any, TestResult], None]:
+        """Yield all tests for one device before moving to the next."""
+        for device, tests in device_to_tests.items():
+            while tests:
+                test_def = tests.pop(0)
+                coro = self._create_test_coroutine(test_def, device, manager)
+                if coro is not None:
+                    yield coro
+
+    async def _generate_device_by_count(
+        self,
+        device_to_tests: dict[AntaDevice, list[AntaTestDefinition]],
+        tests_per_device: int,
+        manager: ResultManager | None = None,
+    ) -> AsyncGenerator[Coroutine[Any, Any, TestResult], None]:
+        """In each round, yield up to `tests_per_device` tests for each device."""
+        while any(device_to_tests.values()):
+            for device, tests in device_to_tests.items():
+                count = min(tests_per_device, len(tests))
+                for _ in range(count):
+                    test_def = tests.pop(0)
+                    coro = self._create_test_coroutine(test_def, device, manager)
+                    if coro is not None:
+                        yield coro
+
+    def _create_test_coroutine(
+        self, test_def: AntaTestDefinition, device: AntaDevice, manager: ResultManager | None = None
+    ) -> Coroutine[Any, Any, TestResult] | None:
+        """Create a test coroutine from a test definition."""
+        try:
+            test_instance = test_def.test(device=device, inputs=test_def.inputs)
+            if manager is not None:
+                manager.add(test_instance.result)
+            coroutine = test_instance.test()
+        except Exception as e:  # noqa: BLE001
+            # An AntaTest instance is potentially user-defined code.
+            # We need to catch everything and exit gracefully with an error message.
+            message = "\n".join(
+                [
+                    f"There is an error when creating test {test_def.test.__module__}.{test_def.test.__name__}.",
+                    f"If this is not a custom test implementation: {GITHUB_SUGGESTION}",
+                ],
+            )
+            anta_log_exception(e, message, logger)
+            return None
+        return coroutine
+
+    def _log_run_information(self, *, dry_run: bool = False) -> None:
+        """Log ANTA run information and potential resource limit warnings."""
+        if self._inventory_stats is None:
+            msg = "The inventory stats are not available. ANTA must be executed through AntaRunner.run()"
+            raise RuntimeError(msg)
+
+        width = min(int(console.width) - 34, len(" Total potential connections: 100000000\n"))
+
+        # Build device information
+        device_lines = [
+            "Devices:",
+            f" Total: {self._inventory_stats.total}",
+        ]
+        if self._inventory_stats.filtered_by_tags > 0:
+            device_lines.append(f" Excluded by tags: {self._inventory_stats.filtered_by_tags}")
+        if self._inventory_stats.connection_failed > 0:
+            device_lines.append(f" Failed to connect: {self._inventory_stats.connection_failed}")
+        device_lines.append(f" Selected: {self._inventory_stats.established}{' (dry-run mode)' if dry_run else ''}")
+
+        # Build connection information
+        connections_line = "" if self._potential_connections is None else f" Total potential connections: {self._potential_connections}\n"
+
+        run_info = (
+            f"{' ANTA NRFU Run Information ':-^{width}}\n"
+            f"{chr(10).join(device_lines)}\n"
+            f"Tests: {self._total_tests} total\n"
+            f"Limits:\n"
+            f" Max concurrent tests: {self._settings.max_concurrency}\n"
+            f"{connections_line}"
+            f" Max file descriptors: {self._settings.file_descriptor_limit}\n"
+            f"{'':-^{width}}"
+        )
+        logger.info(run_info)
+
+        # Log warnings for potential resource limits
+        if self._total_tests > self._settings.max_concurrency:
+            logger.warning(
+                "Tests count (%s) exceeds concurrent limit (%s). Tests will be throttled. Please consult the ANTA FAQ.",
+                self._total_tests,
+                self._settings.max_concurrency,
+            )
+        if self._potential_connections is not None and self._potential_connections > self._settings.file_descriptor_limit:
+            logger.warning(
+                "Potential connections (%s) exceeds file descriptor limit (%s). Connection errors may occur. Please consult the ANTA FAQ.",
+                self._potential_connections,
+                self._settings.file_descriptor_limit,
+            )
+
+    def _log_cache_statistics(self) -> None:
+        """Log cache statistics for each device in the inventory."""
+        if self._selected_inventory is None:
+            msg = "The selected inventory is not available. ANTA must be executed through AntaRunner.run()"
+            raise RuntimeError(msg)
+
+        for device in self._selected_inventory.devices:
+            if device.cache_statistics is not None:
+                msg = (
+                    f"Cache statistics for '{device.name}': "
+                    f"{device.cache_statistics['cache_hits']} hits / {device.cache_statistics['total_commands_sent']} "
+                    f"command(s) ({device.cache_statistics['cache_hit_ratio']})"
+                )
+                logger.info(msg)
+            else:
+                logger.info("Caching is not enabled on %s", device.name)
diff --git a/anta/cli/nrfu/utils.py b/anta/cli/nrfu/utils.py
index 60c0d2976..1928d2f89 100644
--- a/anta/cli/nrfu/utils.py
+++ b/anta/cli/nrfu/utils.py
@@ -50,6 +50,7 @@ def run_tests(ctx: click.Context) -> None:
     print_settings(inventory, catalog)
 
     with anta_progress_bar() as AntaTest.progress:
+        # TODO: Use AntaRunner in ANTA v2.0.0
         asyncio.run(
             main(
                 ctx.obj["result_manager"],
diff --git a/anta/cli/utils.py b/anta/cli/utils.py
index 508424dd0..71f4819ce 100644
--- a/anta/cli/utils.py
+++ b/anta/cli/utils.py
@@ -163,6 +163,7 @@ def core_options(f: Callable[..., Any]) -> Callable[..., Any]:
         show_envvar=True,
         envvar="ANTA_TIMEOUT",
         show_default=True,
+        type=float,
     )
     @click.option(
         "--insecure",
diff --git a/anta/device.py b/anta/device.py
index 7c1e6f642..502360f0f 100644
--- a/anta/device.py
+++ b/anta/device.py
@@ -9,6 +9,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict
+from operator import attrgetter
 from time import monotonic
 from typing import TYPE_CHECKING, Any, Literal
@@ -114,16 +115,19 @@ class AntaDevice(ABC):
         True if the device IP is reachable and a port can be open.
     established : bool
         True if remote command execution succeeds.
-    hw_model : str
+    hw_model : str | None
         Hardware model of the device.
     tags : set[str]
         Tags for this device.
     cache : AntaCache | None
         In-memory cache for this device (None if cache is disabled).
-    cache_locks : dict
+    cache_locks : defaultdict[str, asyncio.Lock] | None
         Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled.
         Deprecated, will be removed in ANTA v2.0.0, use self.cache.locks instead.
-
+    max_connections : int | None
+        For informational/logging purposes only. Can be used by the runner to verify that
+        the total potential connections of a run do not exceed the system's file descriptor limit.
+        This does **not** affect the actual device configuration. None if not available.
     """
 
     def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bool = False) -> None:
@@ -159,6 +163,11 @@ def _keys(self) -> tuple[Any, ...]:
         """Read-only property to implement hashing and equality for AntaDevice classes."""
 
+    @property
+    def max_connections(self) -> int | None:
+        """Maximum number of concurrent connections allowed by the device. Can be overridden by subclasses, returns None if not available."""
+        return None
+
     def __eq__(self, other: object) -> bool:
         """Implement equality for AntaDevice objects."""
         return self._keys == other._keys if isinstance(other, self.__class__) else False
@@ -304,7 +313,7 @@ async def copy(self, sources: list[Path], destination: Path, direction: Literal[
 
 # pylint: disable=too-many-instance-attributes
 class AsyncEOSDevice(AntaDevice):
-    """Implementation of AntaDevice for EOS using aio-eapi.
+    """Implementation of AntaDevice for EOS using the `asynceapi` library, which is built on HTTPX.
 
     Attributes
     ----------
@@ -318,7 +327,6 @@ class AsyncEOSDevice(AntaDevice):
         Hardware model of the device.
     tags : set[str]
         Tags for this device.
-
     """
 
     def __init__(  # noqa: PLR0913
@@ -329,9 +337,9 @@ def __init__(  # noqa: PLR0913
         name: str | None = None,
         enable_password: str | None = None,
         port: int | None = None,
-        ssh_port: int | None = 22,
+        ssh_port: int = 22,
         tags: set[str] | None = None,
-        timeout: float | None = None,
+        timeout: float | None = 30.0,
         proto: Literal["http", "https"] = "https",
         *,
         enable: bool = False,
@@ -350,8 +358,6 @@ def __init__(  # noqa: PLR0913
             Password to connect to eAPI and SSH.
         name
             Device name.
-        enable
-            Collect commands using privileged mode.
         enable_password
             Password used to gain privileged access on EOS.
         port
@@ -361,14 +367,15 @@ def __init__(  # noqa: PLR0913
         tags
             Tags for this device.
         timeout
-            Timeout value in seconds for outgoing API calls.
-        insecure
-            Disable SSH Host Key validation.
+            Global timeout value in seconds for outgoing eAPI calls. None means no timeout.
         proto
             eAPI protocol. Value can be 'http' or 'https'.
+        enable
+            Collect commands using privileged mode.
+        insecure
+            Disable SSH Host Key validation.
         disable_cache
             Disable caching for all commands for this device.
-
         """
         if host is None:
             message = "'host' is required to create an AsyncEOSDevice"
@@ -385,15 +392,26 @@ def __init__(  # noqa: PLR0913
             message = f"'password' is required to instantiate device '{self.name}'"
             logger.error(message)
             raise ValueError(message)
+
         self.enable = enable
         self._enable_password = enable_password
-        self._session: asynceapi.Device = asynceapi.Device(host=host, port=port, username=username, password=password, proto=proto, timeout=timeout)
-        ssh_params: dict[str, Any] = {}
+
+        # Build the session settings for the `asynceapi` client
+        session_settings: dict[str, Any] = {
+            "host": host,
+            "port": port,
+            "username": username,
+            "password": password,
+            "proto": proto,
+            "timeout": timeout,
+        }
+        self._session: asynceapi.Device = asynceapi.Device(**session_settings)
+
+        # Build the SSH connection options
+        ssh_settings = {"host": host, "port": ssh_port, "username": username, "password": password, "client_keys": CLIENT_KEYS}
         if insecure:
-            ssh_params["known_hosts"] = None
-        self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions(
-            host=host, port=ssh_port, username=username, password=password, client_keys=CLIENT_KEYS, **ssh_params
-        )
+            ssh_settings["known_hosts"] = None
+        self._ssh_opts: SSHClientConnectionOptions = SSHClientConnectionOptions(**ssh_settings)
 
         # In Python 3.9, Semaphore must be created within a running event loop
         # TODO: Once we drop Python 3.9 support, initialize the semaphore here
@@ -417,6 +435,7 @@ def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
             _ssh_opts["kwargs"]["password"] = removed_pw
         yield ("_session", vars(self._session))
         yield ("_ssh_opts", _ssh_opts)
+        yield ("max_connections", self.max_connections) if self.max_connections is not None else ("max_connections", "N/A")
 
     def __repr__(self) -> str:
         """Return a printable representation of an AsyncEOSDevice."""
@@ -442,6 +461,14 @@ def _keys(self) -> tuple[Any, ...]:
         """
         return (self._session.host, self._session.port)
 
+    @property
+    def max_connections(self) -> int | None:
+        """Maximum number of concurrent connections allowed by the device. Returns None if not available."""
+        try:
+            return attrgetter("_transport._pool._max_connections")(self._session)
+        except AttributeError:
+            return None
+
     async def _get_semaphore(self) -> asyncio.Semaphore:
         """Return the semaphore, initializing it if needed.
diff --git a/anta/inventory/__init__.py b/anta/inventory/__init__.py
index f98c42f29..b3afc3f8f 100644
--- a/anta/inventory/__init__.py
+++ b/anta/inventory/__init__.py
@@ -177,7 +177,7 @@ def parse(
         username: str,
         password: str,
         enable_password: str | None = None,
-        timeout: float | None = None,
+        timeout: float | None = 30.0,
         *,
         enable: bool = False,
         insecure: bool = False,
@@ -198,7 +198,7 @@ def parse(
         enable_password
             Enable password to use if required.
         timeout
-            Timeout value in seconds for outgoing API calls.
+            Global timeout value in seconds for outgoing eAPI calls. None means no timeout.
         enable
             Whether or not the commands need to be run in enable mode towards the devices.
         insecure
diff --git a/anta/reporter/__init__.py b/anta/reporter/__init__.py
index 5156ea7e8..b6f061c10 100644
--- a/anta/reporter/__init__.py
+++ b/anta/reporter/__init__.py
@@ -29,7 +29,7 @@ class ReportTable:
     """TableReport Generate a Table based on TestResult."""
 
     @dataclass
-    class Headers:  # pylint: disable=too-many-instance-attributes
+    class Headers:
         """Headers for the table report."""
 
         device: str = "Device"
diff --git a/anta/result_manager/__init__.py b/anta/result_manager/__init__.py
index d0a348861..7e2c71560 100644
--- a/anta/result_manager/__init__.py
+++ b/anta/result_manager/__init__.py
@@ -21,7 +21,6 @@
 logger = logging.getLogger(__name__)
 
 
-# pylint: disable=too-many-instance-attributes
 class ResultManager:
     """Manager of ANTA Results.
diff --git a/anta/result_manager/models.py b/anta/result_manager/models.py
index a18ff579c..a19c969de 100644
--- a/anta/result_manager/models.py
+++ b/anta/result_manager/models.py
@@ -122,8 +122,6 @@ def __str__(self) -> str:
         return f"Test '{self.test}' (on '{self.name}'): Result '{self.result}'\nMessages: {self.messages}"
 
 
-# Pylint does not treat dataclasses differently: https://github.com/pylint-dev/pylint/issues/9058
-# pylint: disable=too-many-instance-attributes
 @dataclass
 class DeviceStats:
     """Device statistics for a run of tests."""
diff --git a/anta/runner.py b/anta/runner.py
index 84e27a133..15d5f0280 100644
--- a/anta/runner.py
+++ b/anta/runner.py
@@ -5,16 +5,16 @@
 
 from __future__ import annotations
 
-import asyncio
 import logging
 import os
-import sys
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any
 
+from typing_extensions import deprecated
+
 from anta import GITHUB_SUGGESTION
+from anta._runner import AntaRunner, AntaRunnerFilter
 from anta.logger import anta_log_exception, exc_to_str
-from anta.models import AntaTest
 from anta.tools import Catchtime, cprofile
 
 if TYPE_CHECKING:
@@ -31,6 +31,7 @@
 DEFAULT_NOFILE = 16384
 
 
+@deprecated("This function is deprecated and will be removed in ANTA v2.0.0. Use AntaRunner class instead.", category=DeprecationWarning)
 def adjust_rlimit_nofile() -> tuple[int, int]:
     """Adjust the maximum number of open file descriptors for the ANTA process.
 
@@ -60,6 +61,7 @@ def adjust_rlimit_nofile() -> tuple[int, int]:
 logger = logging.getLogger(__name__)
 
 
+@deprecated("This function is deprecated and will be removed in ANTA v2.0.0. Use AntaRunner class instead.", category=DeprecationWarning)
 def log_cache_statistics(devices: list[AntaDevice]) -> None:
     """Log cache statistics for each device in the inventory.
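The hunks around here mark the legacy `anta.runner` helpers with `typing_extensions.deprecated`, which emits a `DeprecationWarning` when a wrapped function is called. A minimal sketch of what callers (and the benchmark tests further down, which filter these warnings) now see — the inventory and catalog file names are placeholders taken from the `AntaRunner` docstring example, the rest is illustrative:

```python
import warnings

from anta.catalog import AntaCatalog
from anta.inventory import AntaInventory
from anta.runner import prepare_tests  # deprecated by this change

inventory = AntaInventory.parse(filename="anta_inventory.yml", username="arista", password="arista")
catalog = AntaCatalog.parse(filename="anta_catalog.yml")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Calling the deprecated helper still works but now warns.
    prepare_tests(inventory, catalog, tests=None, tags=None)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```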
@@ -80,6 +82,7 @@ def log_cache_statistics(devices: list[AntaDevice]) -> None: logger.info("Caching is not enabled on %s", device.name) +@deprecated("This function is deprecated and will be removed in ANTA v2.0.0. Use AntaRunner class instead.", category=DeprecationWarning) async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devices: set[str] | None, *, established_only: bool) -> AntaInventory | None: """Set up the inventory for the ANTA run. @@ -122,6 +125,7 @@ async def setup_inventory(inventory: AntaInventory, tags: set[str] | None, devic return selected_inventory +@deprecated("This function is deprecated and will be removed in ANTA v2.0.0. Use AntaRunner class instead.", category=DeprecationWarning) def prepare_tests( inventory: AntaInventory, catalog: AntaCatalog, tests: set[str] | None, tags: set[str] | None ) -> defaultdict[AntaDevice, set[AntaTestDefinition]] | None: @@ -178,6 +182,7 @@ def prepare_tests( return device_to_tests +@deprecated("This function is deprecated and will be removed in ANTA v2.0.0. Use AntaRunner class instead.", category=DeprecationWarning) def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinition]], manager: ResultManager | None = None) -> list[Coroutine[Any, Any, TestResult]]: """Get the coroutines for the ANTA run. @@ -250,62 +255,11 @@ async def main( dry_run Build the list of coroutine to run and stop before test execution. """ - if not catalog.tests: - logger.info("The list of tests is empty, exiting") - return - - with Catchtime(logger=logger, message="Preparing ANTA NRFU Run"): - # Setup the inventory - selected_inventory = inventory if dry_run else await setup_inventory(inventory, tags, devices, established_only=established_only) - if selected_inventory is None: - return - - with Catchtime(logger=logger, message="Preparing the tests"): - selected_tests = prepare_tests(selected_inventory, catalog, tests, tags) - if selected_tests is None: - return - final_tests_count = sum(len(tests) for tests in selected_tests.values()) - - run_info = ( - "--- ANTA NRFU Run Information ---\n" - f"Number of devices: {len(inventory)} ({len(selected_inventory)} established)\n" - f"Total number of selected tests: {final_tests_count}\n" - ) - - if os.name == "posix": - # Adjust the maximum number of open file descriptors for the ANTA process - limits = adjust_rlimit_nofile() - run_info += f"Maximum number of open file descriptors for the current ANTA process: {limits[0]}\n" - else: - # Running on non-Posix system, cannot manage the resource. - limits = (sys.maxsize, sys.maxsize) - run_info += "Running on a non-POSIX system, cannot adjust the maximum number of file descriptors.\n" - - run_info += "---------------------------------" - - logger.info(run_info) - - if final_tests_count > limits[0]: - logger.warning( - "The number of concurrent tests is higher than the open file descriptors limit for this ANTA process.\n" - "Errors may occur while running the tests.\n" - "Please consult the ANTA FAQ." 
- ) - - coroutines = get_coroutines(selected_tests, manager if dry_run else None) - - if dry_run: - logger.info("Dry-run mode, exiting before running the tests.") - for coro in coroutines: - coro.close() - return - - if AntaTest.progress is not None: - AntaTest.nrfu_task = AntaTest.progress.add_task("Running NRFU Tests...", total=len(coroutines)) - - with Catchtime(logger=logger, message="Running ANTA tests"): - results = await asyncio.gather(*coroutines) - for result in results: - manager.add(result) - - log_cache_statistics(selected_inventory.devices) + runner = AntaRunner(inventory=inventory, catalog=catalog, manager=manager) + scope = AntaRunnerFilter( + devices=devices, + tests=tests, + tags=tags, + established_only=established_only, + ) + await runner.run(scope, dry_run=dry_run) diff --git a/anta/settings.py b/anta/settings.py new file mode 100644 index 000000000..fb574d679 --- /dev/null +++ b/anta/settings.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023-2025 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Settings for ANTA.""" + +from __future__ import annotations + +import logging +import os +import sys +from enum import Enum +from typing import Any + +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings, SettingsConfigDict + +logger = logging.getLogger(__name__) + +# Default value for the maximum number of concurrent tests in the event loop +DEFAULT_MAX_CONCURRENCY = 10000 + +# Default value for the maximum number of open file descriptors for the ANTA process +DEFAULT_NOFILE = 16384 + +# Default value for the test scheduling strategy +DEFAULT_SCHEDULING_STRATEGY = "round-robin" + +# Default value for the number of tests to schedule per device when using the DEVICE_BY_COUNT scheduling strategy +DEFAULT_SCHEDULING_TESTS_PER_DEVICE = 100 + + +class AntaRunnerSchedulingStrategy(str, Enum): + """Enum for the test scheduling strategies available in the ANTA runner. + + * ROUND_ROBIN: Distribute tests across devices in a round-robin fashion. + * DEVICE_BY_DEVICE: Run all tests on a single device before moving to the next. + * DEVICE_BY_COUNT: Run all tests on a single device for a specified count before moving to the next. + + NOTE: This could be updated to StrEnum when Python 3.11 is the minimum supported version in ANTA. + """ + + ROUND_ROBIN = "round-robin" + DEVICE_BY_DEVICE = "device-by-device" + DEVICE_BY_COUNT = "device-by-count" + + def __str__(self) -> str: + """Override the __str__ method to return the value of the Enum, mimicking the behavior of StrEnum.""" + return self.value + + +class AntaRunnerSettings(BaseSettings): + """Environment variables for configuring the ANTA runner. + + When initialized, relevant environment variables are loaded. If not set, default values are used. + + On POSIX systems, also adjusts the process's soft limit based on the `ANTA_NOFILE` environment variable + while respecting the system's hard limit, meaning the new soft limit cannot exceed the system's hard limit. + + On non-POSIX systems (Windows), sets the limit to `sys.maxsize`. + + The adjusted limit is available with the `file_descriptor_limit` property after initialization. + + Attributes + ---------- + nofile : PositiveInt + Environment variable: ANTA_NOFILE + + The maximum number of open file descriptors for the ANTA process. Defaults to 16384. 
+ + max_concurrency : PositiveInt + Environment variable: ANTA_MAX_CONCURRENCY + + The maximum number of concurrent tests that can run in the event loop. Defaults to 10000. + + scheduling_strategy : AntaRunnerSchedulingStrategy + Environment variable: ANTA_SCHEDULING_STRATEGY + + The test scheduling strategy to use when running tests. Defaults to "round-robin". + + scheduling_tests_per_device : PositiveInt + Environment variable: ANTA_SCHEDULING_TESTS_PER_DEVICE + + The number of tests to schedule per device when using the `DEVICE_BY_COUNT` scheduling strategy. Defaults to 100. + """ + + model_config = SettingsConfigDict(env_prefix="ANTA_") + + nofile: PositiveInt = Field(default=DEFAULT_NOFILE) + max_concurrency: PositiveInt = Field(default=DEFAULT_MAX_CONCURRENCY) + scheduling_strategy: AntaRunnerSchedulingStrategy = Field(default=AntaRunnerSchedulingStrategy(DEFAULT_SCHEDULING_STRATEGY)) + scheduling_tests_per_device: PositiveInt = Field(default=DEFAULT_SCHEDULING_TESTS_PER_DEVICE) + + # Computed in post-init + _file_descriptor_limit: PositiveInt + + def model_post_init(self, __context: Any) -> None: # noqa: ANN401, PYI063 + """Post-initialization method to set the file descriptor limit for the current ANTA process.""" + if os.name != "posix": + logger.warning("Running on a non-POSIX system, cannot adjust the maximum number of file descriptors.") + self._file_descriptor_limit = sys.maxsize + return + + import resource + + limits = resource.getrlimit(resource.RLIMIT_NOFILE) + logger.debug("Initial file descriptor limits: Soft Limit: %s | Hard Limit: %s", limits[0], limits[1]) + + # Set new soft limit to minimum of requested and hard limit + new_soft_limit = min(limits[1], self.nofile) + logger.debug("Setting file descriptor soft limit to %s", new_soft_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft_limit, limits[1])) + + self._file_descriptor_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0] + return + + @property + def file_descriptor_limit(self) -> PositiveInt: + """The maximum number of file descriptors available to the process.""" + return self._file_descriptor_limit diff --git a/anta/tools.py b/anta/tools.py index a0a21a673..6904935ce 100644 --- a/anta/tools.py +++ b/anta/tools.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import cProfile import os import pstats @@ -19,6 +20,7 @@ if TYPE_CHECKING: import sys + from collections.abc import AsyncIterator, Coroutine from logging import Logger from types import TracebackType @@ -28,6 +30,7 @@ from typing_extensions import Self F = TypeVar("F", bound=Callable[..., Any]) +T = TypeVar("T") def get_failed_logs(expected_output: dict[Any, Any], actual_output: dict[Any, Any]) -> str: @@ -415,3 +418,45 @@ def format_data(data: dict[str, bool]) -> str: "Advertised: True, Received: True, Enabled: True" """ return ", ".join(f"{k.capitalize()}: {v}" for k, v in data.items()) + + +async def limit_concurrency(coroutines: AsyncIterator[Coroutine[Any, Any, T]], limit: int) -> AsyncIterator[asyncio.Task[T]]: + """Schedule a limited number of coroutines concurrently. + + Inspired by: https://death.andgravity.com/limit-concurrency + + Parameters + ---------- + coroutines + An async iterator of coroutines. + limit + The maximum number of coroutines to run concurrently. + + Yields + ------ + Each completed task. + """ + if limit <= 0: + msg = "Concurrency limit must be greater than 0." 
+        raise RuntimeError(msg)
+
+    coros_ended = False
+    pending: set[asyncio.Task[T]] = set()
+
+    while pending or not coros_ended:
+        while len(pending) < limit and not coros_ended:
+            try:
+                # NOTE: The `anext` built-in function is not available in Python 3.9
+                coro = await coroutines.__anext__()  # pylint: disable=unnecessary-dunder-call
+            except StopAsyncIteration:  # noqa: PERF203
+                coros_ended = True
+            else:
+                pending.add(asyncio.create_task(coro))
+
+        if not pending:
+            return
+
+        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+
+        while done:
+            yield done.pop()
diff --git a/docs/api/settings.md b/docs/api/settings.md
new file mode 100644
index 000000000..bbc45cae4
--- /dev/null
+++ b/docs/api/settings.md
@@ -0,0 +1,13 @@
+---
+anta_title: ANTA Settings
+---
+<!--
+  ~ Copyright (c) 2023-2025 Arista Networks, Inc.
+  ~ Use of this source code is governed by the Apache License 2.0
+  ~ that can be found in the LICENSE file.
+  -->
+
+### ::: anta.settings
+
+    options:
+      show_root_full_path: true
diff --git a/docs/faq.md b/docs/faq.md
index 204972452..15ea6a4c1 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -30,7 +30,7 @@ anta_title: Frequently Asked Questions (FAQ)
     This usually means that the operating system refused to open a new file descriptor (or socket) for the ANTA process. This might be due to the hard limit for open file descriptors currently set for the ANTA process.
 
-    At startup, ANTA sets the soft limit of its process to the hard limit up to 16384. This is because the soft limit is usually 1024 and the hard limit is usually higher (depends on the system). If the hard limit of the ANTA process is still lower than the number of selected tests in ANTA, the ANTA process may request to the operating system too many file descriptors and get an error, a WARNING is displayed at startup if this is the case.
+    At startup, ANTA sets the soft limit of its process to the hard limit up to 16384. This is because the soft limit is usually 1024 and the hard limit is usually higher (depends on the system). If the hard limit of the ANTA process is still lower than the potential connections of all devices, the ANTA process may request too many file descriptors from the operating system and get an error; a WARNING is displayed at startup if this is the case.
 
 ### Solution
 
@@ -43,11 +43,39 @@ anta_title: Frequently Asked Questions (FAQ)
     The `user` is the one with which the ANTA process is started. The `value` is the new hard limit. The maximum value depends on the system. A hard limit of 16384 should be sufficient for ANTA to run in most high scale scenarios. After creating this file, log out the current session and log in again.
 
+## Tests throttling WARNING in the logs
+
+???+ faq "Tests throttling `WARNING` in the logs"
+
+    ANTA is designed to execute many tests concurrently while ensuring system stability. If the total test count exceeds the maximum concurrency limit, tests are throttled to avoid overwhelming the asyncio event loop and exhausting system resources. A `WARNING` message is logged at startup when this occurs.
+
+    By default, ANTA schedules up to **10000** tests concurrently. This default is a balance between performance and stability, but it may not be optimal for every system. If the number of tests exceeds this value, ANTA schedules the first 10000 tests and waits for some tests to complete before scheduling more.
+
+    ### Solution
+
+    You can adjust the maximum concurrency limit using the `ANTA_MAX_CONCURRENCY` environment variable. The optimal value depends on your system's CPU usage, memory consumption, and file descriptor limits.
+
+    !!! warning
+
+        Increasing the maximum concurrency limit can lead to system instability if the system is not able to handle the increased load. Monitor system resources and adjust the limit accordingly.
+
+    !!! info "Device Connection Limits"
+
+        Each device is limited to a maximum of **100** concurrent connections. This means that, even if ANTA schedules a high number of tests, each device will only attempt to open up to 100 connections at a time. Furthermore, Arista EOS eAPI is inherently protected against overload and is designed to handle high connection volumes safely.
+
+    ANTA also offers several test scheduling strategies to optimize test execution, which is particularly relevant when the total number of tests exceeds the maximum concurrency limit. The strategy is configurable via the `ANTA_SCHEDULING_STRATEGY` environment variable (default is `round-robin`), along with `ANTA_SCHEDULING_TESTS_PER_DEVICE` (default is **100**) for the `device-by-count` strategy:
+
+    - **Round-robin (`round-robin`)**: Distributes tests evenly across devices. This is generally suitable for small to medium-sized (around 200 devices) fabrics but can open many simultaneous connections.
+    - **Device-by-Device (`device-by-device`)**: Executes all tests for one device before moving on to the next, which may help reduce peak concurrent connections.
+    - **Device-by-Count (`device-by-count`)**: Limits the number of tests scheduled per device in each round. This provides finer control in larger environments where opening too many connections simultaneously might exceed system limits.
+
+    **Recommendation:** If you're running ANTA on a large fabric or encounter issues related to resource limits, consider tuning these settings. Test different configurations to find the optimal balance for your system.
+
 ## `Timeout` error in the logs
 
 ???+ faq "`Timeout` error in the logs"
 
-    When running ANTA, you can receive `Timeout` errors in the logs (could be ReadTimeout, WriteTimeout, ConnectTimeout or PoolTimeout). More details on the timeouts of the underlying library are available here: https://www.python-httpx.org/advanced/timeouts.
+    When running ANTA, you can receive `Timeout` errors in the logs (could be `ReadTimeout`, `WriteTimeout`, `ConnectTimeout` or `PoolTimeout`). More details on the timeouts of the underlying library are available here: https://www.python-httpx.org/advanced/timeouts.
 
     This might be due to the time the host on which ANTA is run takes to reach the target devices (for instance if going through firewalls, NATs, ...) or when a lot of tests are being run at the same time on a device (eAPI has a queue mechanism to avoid exhausting EOS resources because of a high number of simultaneous eAPI requests).
 
@@ -59,8 +87,7 @@ anta_title: Frequently Asked Questions (FAQ)
     anta nrfu --enable --username username --password arista --inventory inventory.yml -c nrfu.yml --timeout 50 text
     ```
 
-    The previous command set a couple of options for ANTA NRFU, one them being the `timeout` command, by default, when running ANTA from CLI, it is set to 30s.
-    The timeout is increased to 50s to allow ANTA to wait for API calls a little longer.
+    In this command, ANTA NRFU is configured with several options. Notably, the `--timeout` parameter is set to 50 seconds (instead of the default 30 seconds) to allow extra time for API calls to complete.
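To make the knobs from the FAQ entry above concrete, here is a small sketch using the `AntaRunnerSettings` class this diff adds in `anta/settings.py`. The environment variable names follow its `ANTA_` prefix; the values are arbitrary examples:

```python
import os

from anta.settings import AntaRunnerSettings

# AntaRunnerSettings reads ANTA_-prefixed environment variables at instantiation time.
os.environ["ANTA_MAX_CONCURRENCY"] = "500"
os.environ["ANTA_SCHEDULING_STRATEGY"] = "device-by-count"
os.environ["ANTA_SCHEDULING_TESTS_PER_DEVICE"] = "10"

settings = AntaRunnerSettings()
assert settings.max_concurrency == 500
assert str(settings.scheduling_strategy) == "device-by-count"

# On POSIX systems, the soft RLIMIT_NOFILE is adjusted during initialization
# (capped at the hard limit); on other systems this property is sys.maxsize.
print(settings.file_descriptor_limit)
```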
 ## `ImportError` related to `urllib3`
diff --git a/mkdocs.yml b/mkdocs.yml
index 5b3c6dcb7..10d5f0267 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -2,7 +2,7 @@
 site_name: Arista Network Test Automation - ANTA
 site_author: Khelil Sator
 site_description: Arista Network Test Automation
-copyright: Copyright © 2019 - 2024 Arista Networks
+copyright: Copyright © 2019 - 2025 Arista Networks
 
 # Repository
 repo_name: ANTA on Github
@@ -185,9 +185,9 @@ nav:
       - Debug commands: cli/debug.md
       - Tag Management: cli/tag-management.md
     - Advanced Usages:
-        - Caching in ANTA: advanced_usages/caching.md
-        - Developing ANTA tests: advanced_usages/custom-tests.md
-        - ANTA as a Python Library: advanced_usages/as-python-lib.md
+      - Caching in ANTA: advanced_usages/caching.md
+      - Developing ANTA tests: advanced_usages/custom-tests.md
+      - ANTA as a Python Library: advanced_usages/as-python-lib.md
@@ -239,6 +239,7 @@ nav:
       - CSV: api/reporter/csv.md
       - Jinja: api/reporter/jinja.md
     - Runner: api/runner.md
+    - Settings: api/settings.md
   - Troubleshooting ANTA: troubleshooting.md
   - Contributions: contribution.md
   - FAQ: faq.md
diff --git a/pyproject.toml b/pyproject.toml
index 4c37604e9..ecbb8da8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
   "Jinja2>=3.1.2",
   "pydantic>=2.7",
   "pydantic-extra-types>=2.3.0",
+  "pydantic-settings>=2.6.0",
   "PyYAML>=6.0",
   "requests>=2.31.0",
   "rich>=13.5.2,<14"
@@ -487,3 +488,5 @@ min-similarity-lines=10
 signature-mutators="click.decorators.option"
 load-plugins="pylint_pydantic"
 extension-pkg-whitelist="pydantic"
+# Pylint does not treat dataclasses differently: https://github.com/pylint-dev/pylint/issues/9058
+max-attributes=15
diff --git a/tests/benchmark/test_anta.py b/tests/benchmark/test_anta.py
index 1daf7f369..346ebfd89 100644
--- a/tests/benchmark/test_anta.py
+++ b/tests/benchmark/test_anta.py
@@ -37,6 +37,7 @@ def test_anta_dry_run(
 
     results = session_results[request.node.callspec.id]
 
+    # TODO: Use AntaRunner in ANTA v2.0.0
     @benchmark
     def _() -> None:
         results.reset()
@@ -69,6 +70,7 @@ def test_anta(
 
     results = session_results[request.node.callspec.id]
 
+    # TODO: Use AntaRunner in ANTA v2.0.0
     @benchmark
     def _() -> None:
         results.reset()
diff --git a/tests/benchmark/test_runner.py b/tests/benchmark/test_runner.py
index 9aa54df27..94c84e8eb 100644
--- a/tests/benchmark/test_runner.py
+++ b/tests/benchmark/test_runner.py
@@ -7,6 +7,9 @@
 
 from typing import TYPE_CHECKING, Any
 
+import pytest
+
+from anta._runner import AntaRunner, AntaRunnerFilter
 from anta.result_manager import ResultManager
 from anta.runner import get_coroutines, prepare_tests
 
@@ -22,6 +25,8 @@
     from anta.result_manager.models import TestResult
 
 
+# TODO: Remove this in ANTA v2.0.0
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_prepare_tests(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None:
     """Benchmark `anta.runner.prepare_tests`."""
 
@@ -36,6 +41,8 @@ def _() -> defaultdict[AntaDevice, set[AntaTestDefinition]] | None:
     assert sum(len(tests) for tests in selected_tests.values()) == len(inventory) * len(catalog.tests)
 
 
+# TODO: Remove this in ANTA v2.0.0
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 def test_get_coroutines(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None:
     """Benchmark `anta.runner.get_coroutines`."""
     selected_tests = prepare_tests(inventory=inventory, catalog=catalog, tests=None, tags=None)
@@ -52,3 +59,19 @@ def bench() -> list[Coroutine[Any, Any, TestResult]]:
     count = sum(len(tests) for tests in selected_tests.values())
     assert count == len(coroutines)
+
+
+def test_setup_tests(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None:
+    """Benchmark `anta._runner.AntaRunner._setup_tests`."""
+    runner = AntaRunner(inventory=inventory, catalog=catalog)
+    runner._selected_inventory = inventory
+
+    def bench() -> bool:
+        catalog.clear_indexes()
+        return runner._setup_tests(filters=AntaRunnerFilter())
+
+    benchmark(bench)
+
+    assert runner._selected_tests is not None
+    assert len(runner._selected_tests) == len(inventory)
+    assert sum(len(tests) for tests in runner._selected_tests.values()) == len(inventory) * len(catalog.tests)
diff --git a/tests/units/conftest.py b/tests/units/conftest.py
index 49c786f00..bde8577cc 100644
--- a/tests/units/conftest.py
+++ b/tests/units/conftest.py
@@ -5,20 +5,26 @@
 
 from __future__ import annotations
 
+import os
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
 import yaml
 
+from anta.catalog import AntaCatalog
 from anta.device import AntaDevice, AsyncEOSDevice
+from anta.inventory import AntaInventory
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Generator, Iterator
 
+    from anta._runner import AntaRunner
     from anta.models import AntaCommand
 
+DATA_DIR: Path = Path(__file__).parent.parent.resolve() / "data"
 DEVICE_HW_MODEL = "pytest"
 DEVICE_NAME = "pytest"
 COMMAND_OUTPUT = "retrieved"
@@ -83,3 +89,64 @@ def yaml_file(request: pytest.FixtureRequest, tmp_path: Path) -> Path:
     content: dict[str, Any] = request.param
     file.write_text(yaml.dump(content, allow_unicode=True))
     return file
+
+
+@pytest.fixture
+def anta_runner(request: pytest.FixtureRequest) -> AntaRunner:
+    """AntaRunner fixture.
+
+    Must be parametrized with a dictionary containing the following keys:
+    - inventory: Inventory file name from the data directory
+    - catalog: Catalog file name from the data directory
+
+    Optional keys:
+    - manager: ResultManager instance
+    - max_concurrency: Maximum concurrency limit
+    - nofile: File descriptor limit
+    """
+    # Import must be inside the fixture to prevent circular dependency from breaking CLI tests:
+    # anta.runner -> anta.cli.console -> anta.cli/* (not yet loaded) -> anta.cli.anta
+    from anta._runner import AntaRunner
+    from anta.settings import AntaRunnerSettings
+
+    if not hasattr(request, "param"):
+        msg = "anta_runner fixture requires a parameter dictionary"
+        raise ValueError(msg)
+
+    params = request.param
+
+    # Check required parameters
+    required_params = {"inventory", "catalog"}
+    missing_params = required_params - params.keys()
+    if missing_params:
+        msg = f"anta_runner fixture missing required parameters: {missing_params}"
+        raise ValueError(msg)
+
+    # Build AntaRunner fields
+    runner_fields = {
+        "inventory": AntaInventory.parse(
+            filename=DATA_DIR / params["inventory"],
+            username="arista",
+            password="arista",
+        ),
+        "catalog": AntaCatalog.parse(DATA_DIR / params["catalog"]),
+        "manager": params.get("manager", None),
+    }
+
+    # Build AntaRunnerSettings fields
+    settings_fields = {}
+    if "max_concurrency" in params:
+        settings_fields["max_concurrency"] = params["max_concurrency"]
+    if "nofile" in params:
+        settings_fields["nofile"] = params["nofile"]
+
+    runner = AntaRunner(**runner_fields)
+    runner._settings = AntaRunnerSettings(**settings_fields)
+    return runner
+
+
+@pytest.fixture
+def setenvvar(monkeypatch: pytest.MonkeyPatch) -> Generator[pytest.MonkeyPatch, None, None]:
+    """Fixture to set environment variables for testing."""
+    with mock.patch.dict(os.environ, clear=True):
+        yield monkeypatch
diff --git a/tests/units/test__runner.py b/tests/units/test__runner.py
new file mode 100644
index 000000000..0c82a6dd6
--- /dev/null
+++ b/tests/units/test__runner.py
@@ -0,0 +1,310 @@
+# Copyright (c) 2023-2025 Arista Networks, Inc.
+# Use of this source code is governed by the Apache License 2.0
+# that can be found in the LICENSE file.
+
+"""Test anta._runner.py."""
+
+from __future__ import annotations
+
+import logging
+import os
+
+import pytest
+from pydantic import ValidationError
+
+from anta._runner import AntaRunner, AntaRunnerFilter
+from anta.result_manager import ResultManager
+from anta.settings import AntaRunnerSchedulingStrategy
+
+
+class TestAntaRunnerBasic:
+    """Test AntaRunner basic functionality."""
+
+    @pytest.mark.parametrize(
+        ("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml", "manager": ResultManager()}], indirect=True
+    )
+    def test_init(self, anta_runner: AntaRunner) -> None:
+        """Test basic initialization."""
+        assert anta_runner.manager is not None
+        assert len(anta_runner.inventory.devices) == 3
+        assert len(anta_runner.catalog.tests) == 11
+        assert len(anta_runner.manager.results) == 0
+
+        # Check private attributes are initialized
+        assert anta_runner._selected_inventory is None
+        assert anta_runner._selected_tests is None
+        assert anta_runner._inventory_stats is None
+        assert anta_runner._total_tests == 0
+        assert anta_runner._potential_connections is None
+
+        # Check default settings
+        assert anta_runner._settings.max_concurrency == 10000
+        assert anta_runner._settings.scheduling_strategy == "round-robin"
+        assert anta_runner._settings.scheduling_tests_per_device == 100
+
+    @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+    async def test_reset(self, anta_runner: AntaRunner) -> None:
+        """Test AntaRunner.reset method."""
+        await anta_runner.run(dry_run=True)
+
+        # After a run, the following attributes should be set
+        assert anta_runner._selected_inventory is not None
+        assert anta_runner._selected_tests is not None
+        assert anta_runner._inventory_stats is not None
+        assert anta_runner._total_tests != 0
+        assert anta_runner._potential_connections is not None
+
+        anta_runner.reset()
+
+        # After reset, the following attributes should be None
+        assert anta_runner._selected_inventory is None
+        assert anta_runner._selected_tests is None
+        assert anta_runner._inventory_stats is None
+        assert anta_runner._total_tests == 0
+        assert anta_runner._potential_connections is None
+
+
+class TestAntaRunnerRun:
+    """Test AntaRunner.run method."""
+
+    @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+    async def test_run_dry_run(self, caplog: pytest.LogCaptureFixture, anta_runner: AntaRunner) -> None:
+        """Test AntaRunner.run method in dry-run mode."""
+        caplog.set_level(logging.INFO)
+
+        await anta_runner.run(dry_run=True)
+
+        # In dry-run mode, the selected inventory is the original inventory
+        assert anta_runner._selected_inventory is not None
+        assert len(anta_runner._selected_inventory) == len(anta_runner.inventory)
+
+        # In dry-run mode, the inventory stats total should match the original inventory length
+        assert anta_runner._inventory_stats is not None
+        assert anta_runner._inventory_stats.total == len(anta_runner.inventory)
+
+        assert "Dry-run mode, exiting before running the tests." in caplog.records[-1].message
+
+    @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+    async def test_run_invalid_filters(self, anta_runner: AntaRunner) -> None:
+        """Test AntaRunner.run method with invalid filters."""
+        with pytest.raises(ValidationError, match="1 validation error for AntaRunnerFilter"):
+            await anta_runner.run(filters=AntaRunnerFilter(devices="invalid", tests=None, tags=None), dry_run=True)  # type: ignore[arg-type]
+
+    @pytest.mark.parametrize(
+        ("anta_runner", "filters", "expected_devices", "expected_tests"),
+        [
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests=None, tags=None),
+                3,
+                27,
+                id="all-tests",
+            ),
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests=None, tags={"leaf"}),
+                2,
+                6,
+                id="1-tag",
+            ),
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests=None, tags={"leaf", "spine"}),
+                3,
+                9,
+                id="2-tags",
+            ),
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests={"VerifyMlagStatus", "VerifyUptime"}, tags=None),
+                3,
+                5,
+                id="filtered-tests",
+            ),
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests={"VerifyMlagStatus", "VerifyUptime"}, tags={"leaf"}),
+                2,
+                4,
+                id="1-tag-filtered-tests",
+            ),
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests=None, tags={"invalid"}),
+                0,
+                0,
+                id="invalid-tag",
+            ),
+            pytest.param(
+                {"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"},
+                AntaRunnerFilter(devices=None, tests=None, tags={"dc1"}),
+                0,
+                0,
+                id="device-tag-no-tests",
+            ),
+        ],
+        indirect=["anta_runner"],
+    )
+    async def test_run_filters(
+        self, caplog: pytest.LogCaptureFixture, anta_runner: AntaRunner, filters: AntaRunnerFilter, expected_devices: int, expected_tests: int
+    ) -> None:
+        """Test AntaRunner.run method with different filters."""
+        caplog.set_level(logging.WARNING)
+
+        await anta_runner.run(filters, dry_run=True)
+
+        # Check when all tests are filtered out
+        if expected_devices == 0 and expected_tests == 0:
+            assert anta_runner._total_tests == 0
+            assert anta_runner._selected_tests is None
+            msg = f"There are no tests matching the tags {filters.tags} to run in the current test catalog and device inventory, please verify your inputs."
+ assert msg in caplog.messages
+ return
+
+ assert anta_runner._selected_tests is not None
+ assert len(anta_runner._selected_tests) == expected_devices
+ assert sum(len(tests) for tests in anta_runner._selected_tests.values()) == expected_tests
+
+ @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+ async def test_multiple_runs_no_manager(self, anta_runner: AntaRunner) -> None:
+ """Test multiple runs without a ResultManager instance."""
+ assert anta_runner.manager is None
+
+ first_run_manager = await anta_runner.run(dry_run=True)
+ assert isinstance(first_run_manager, ResultManager)
+ assert len(first_run_manager.results) == 27
+
+ second_run_manager = await anta_runner.run(dry_run=True)
+ assert isinstance(second_run_manager, ResultManager)
+ assert len(second_run_manager.results) == 27
+
+ @pytest.mark.parametrize(
+ ("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml", "manager": ResultManager()}], indirect=True
+ )
+ async def test_multiple_runs_with_manager(self, anta_runner: AntaRunner) -> None:
+ """Test multiple runs with a provided ResultManager instance."""
+ assert anta_runner.manager is not None
+
+ first_run_manager = await anta_runner.run(dry_run=True)
+ assert len(first_run_manager.results) == 27
+ assert first_run_manager.results == anta_runner.manager.results
+
+ # When a manager is provided, results from subsequent runs are appended to the manager
+ second_run_manager = await anta_runner.run(dry_run=True)
+ assert len(second_run_manager.results) == 54
+ assert first_run_manager.results == second_run_manager.results
+
+ @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+ async def test_run_device_by_device_strategy(self, anta_runner: AntaRunner) -> None:
+ """Test AntaRunner.run method with device-by-device scheduling strategy."""
+ manager = ResultManager()
+ anta_runner._selected_inventory = anta_runner.inventory
+ anta_runner._setup_tests(filters=AntaRunnerFilter())
+ anta_runner._settings.scheduling_strategy = AntaRunnerSchedulingStrategy.DEVICE_BY_DEVICE
+
+ # Exhaust the generator and close the coroutines
+ async for coro in anta_runner._test_generator(manager):
+ coro.close()
+
+ # Check that indices 0-8 all have name "leaf1"
+ assert all(result.name == "leaf1" for result in manager.results[0:9])
+
+ # Check that indices 9-17 all have name "leaf2"
+ assert all(result.name == "leaf2" for result in manager.results[9:18])
+
+ # Check that indices 18-26 all have name "spine1"
+ assert all(result.name == "spine1" for result in manager.results[18:27])
+
+ @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+ async def test_run_device_by_count_strategy(self, anta_runner: AntaRunner) -> None:
+ """Test AntaRunner.run method with device-by-count scheduling strategy."""
+ manager = ResultManager()
+ anta_runner._selected_inventory = anta_runner.inventory
+ anta_runner._setup_tests(filters=AntaRunnerFilter())
+ anta_runner._settings.scheduling_strategy = AntaRunnerSchedulingStrategy.DEVICE_BY_COUNT
+ anta_runner._settings.scheduling_tests_per_device = 2
+
+ # Exhaust the generator and close the coroutines
+ async for coro in anta_runner._test_generator(manager):
+ coro.close()
+
+ # Check that indices 0-1 all have name "leaf1"
+ assert all(result.name == "leaf1" for result in manager.results[0:2])
+
+ # Check that indices 2-3 all have name "leaf2"
+ assert all(result.name == "leaf2" for result in manager.results[2:4])
+
+ # Check that indices 4-5 all have name "spine1"
+ assert all(result.name == "spine1" for result in manager.results[4:6])
+
+ # The last 3 results should be "leaf1", "leaf2", "spine1" since there are no more tests to run
+ assert manager.results[-3].name == "leaf1"
+ assert manager.results[-2].name == "leaf2"
+ assert manager.results[-1].name == "spine1"
+
+ @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+ async def test_run_round_robin_strategy(self, anta_runner: AntaRunner) -> None:
+ """Test AntaRunner.run method with round-robin scheduling strategy."""
+ manager = ResultManager()
+ anta_runner._selected_inventory = anta_runner.inventory
+ anta_runner._setup_tests(filters=AntaRunnerFilter())
+ anta_runner._settings.scheduling_strategy = AntaRunnerSchedulingStrategy.ROUND_ROBIN
+
+ # Exhaust the generator and close the coroutines
+ async for coro in anta_runner._test_generator(manager):
+ coro.close()
+
+ # Round-robin between devices
+ assert manager.results[0].name == "leaf1"
+ assert manager.results[1].name == "leaf2"
+ assert manager.results[2].name == "spine1"
+ assert manager.results[3].name == "leaf1"
+ assert manager.results[4].name == "leaf2"
+ assert manager.results[5].name == "spine1"
+
+
+class TestAntaRunnerLogging:
+ """Test AntaRunner logging."""
+
+ @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml"}], indirect=True)
+ async def test_log_run_information_default(self, caplog: pytest.LogCaptureFixture, anta_runner: AntaRunner) -> None:
+ """Test _log_run_information with default values."""
+ caplog.set_level(logging.INFO)
+
+ await anta_runner.run(dry_run=True)
+
+ expected_output = [
+ "ANTA NRFU Run Information",
+ "Devices:",
+ " Total: 3",
+ " Selected: 0 (dry-run mode)",
+ "Tests: 27 total",
+ "Limits:",
+ " Max concurrent tests: 10000",
+ " Total potential connections: 300",
+ f" Max file descriptors: {anta_runner._settings.file_descriptor_limit}",
+ ]
+ for line in expected_output:
+ assert line in caplog.text
+
+ @pytest.mark.parametrize(
+ ("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml", "max_concurrency": 20}], indirect=True
+ )
+ async def test_log_run_information_concurrency_limit(self, caplog: pytest.LogCaptureFixture, anta_runner: AntaRunner) -> None:
+ """Test _log_run_information with a tests count higher than the concurrency limit."""
+ caplog.set_level(logging.WARNING)
+
+ await anta_runner.run(dry_run=True)
+
+ warning = "Tests count (27) exceeds concurrent limit (20). Tests will be throttled."
+ assert warning in caplog.text
+
+ @pytest.mark.skipif(os.name != "posix", reason="Very unlikely to happen on non-POSIX systems due to sys.maxsize")
+ @pytest.mark.parametrize(("anta_runner"), [{"inventory": "test_inventory_with_tags.yml", "catalog": "test_catalog_with_tags.yml", "nofile": 128}], indirect=True)
+ async def test_log_run_information_file_descriptor_limit(self, caplog: pytest.LogCaptureFixture, anta_runner: AntaRunner) -> None:
+ """Test _log_run_information with a connections count higher than the file descriptor limit."""
+ caplog.set_level(logging.WARNING)
+
+ await anta_runner.run(dry_run=True)
+
+ warning = "Potential connections (300) exceeds file descriptor limit (128). Connection errors may occur."
+ assert warning in caplog.text
diff --git a/tests/units/test_runner.py b/tests/units/test_runner.py
index 1b9c40c88..6d8de26ae 100644
--- a/tests/units/test_runner.py
+++ b/tests/units/test_runner.py
@@ -30,25 +30,47 @@ FAKE_CATALOG: AntaCatalog = AntaCatalog.from_list([(FakeTest, None)])
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
async def test_empty_tests(caplog: pytest.LogCaptureFixture, inventory: AntaInventory) -> None:
"""Test that when the list of tests is empty, a log is raised."""
caplog.set_level(logging.INFO)
manager = ResultManager()
await main(manager, inventory, AntaCatalog())
- assert len(caplog.record_tuples) == 1
- assert "The list of tests is empty, exiting" in caplog.records[0].message
+ # On Windows, there is an extra log message when AntaRunner is initialized and
+ # tries to adjust the file descriptor limit.
+ if os.name != "posix":
+ record_tuples = 2
+ record_index = 1
+ else:
+ record_tuples = 1
+ record_index = 0
+ assert len(caplog.record_tuples) == record_tuples
+ assert "The list of tests is empty, exiting" in caplog.records[record_index].message
+
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
async def test_empty_inventory(caplog: pytest.LogCaptureFixture) -> None:
"""Test that when the Inventory is empty, a log is raised."""
caplog.set_level(logging.INFO)
manager = ResultManager()
await main(manager, AntaInventory(), FAKE_CATALOG)
- assert len(caplog.record_tuples) == 3
- assert "The inventory is empty, exiting" in caplog.records[1].message
+
+ # On Windows, there is an extra log message when AntaRunner is initialized and
+ # tries to adjust the file descriptor limit.
+ if os.name != "posix":
+ record_tuples = 4
+ record_index = 2
+ else:
+ record_tuples = 3
+ record_index = 1
+
+ assert len(caplog.record_tuples) == record_tuples
+ assert "The inventory is empty, exiting" in caplog.records[record_index].message
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
@pytest.mark.parametrize(
("inventory", "tags", "devices"),
[
@@ -71,6 +93,8 @@ async def test_no_selected_device(caplog: pytest.LogCaptureFixture, inventory: A
assert msg in caplog.messages
+# TODO: Remove this in ANTA v2.0.0
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.skipif(os.name != "posix", reason="Cannot run this test on Windows")
def test_adjust_rlimit_nofile_valid_env(caplog: pytest.LogCaptureFixture) -> None:
"""Test adjust_rlimit_nofile with valid environment variables."""
@@ -104,6 +128,8 @@ def side_effect_setrlimit(resource_id: int, limits: tuple[int, int]) -> None:
setrlimit_mock.assert_called_once_with(resource.RLIMIT_NOFILE, (20480, 1048576))
+# TODO: Remove this in ANTA v2.0.0
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.skipif(os.name != "posix", reason="Cannot run this test on Windows")
def test_adjust_rlimit_nofile_invalid_env(caplog: pytest.LogCaptureFixture) -> None:
"""Test adjust_rlimit_nofile with invalid environment variables."""
@@ -138,6 +164,7 @@ def side_effect_setrlimit(resource_id: int, limits: tuple[int, int]) -> None:
setrlimit_mock.assert_called_once_with(resource.RLIMIT_NOFILE, (16384, 1048576))
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
@pytest.mark.skipif(os.name == "posix", reason="Run this test on Windows only")
async def test_check_runner_log_for_windows(caplog: pytest.LogCaptureFixture, inventory: AntaInventory) -> None:
"""Test log output for Windows host regarding rlimit."""
@@ -145,9 +172,10 @@ async def test_check_runner_log_for_windows(caplog: pytest.LogCaptureFixture, in
manager = ResultManager()
# Using dry-run to shorten the test
await main(manager, inventory, FAKE_CATALOG, dry_run=True)
- assert "Running on a non-POSIX system, cannot adjust the maximum number of file descriptors." in caplog.records[-3].message
+ assert "Running on a non-POSIX system, cannot adjust the maximum number of file descriptors." in caplog.records[0].message
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
# We could instead merge multiple coverage reports together but that requires more work than just this.
@pytest.mark.skipif(os.name != "posix", reason="Fake non-posix for coverage")
async def test_check_runner_log_for_windows_fake(caplog: pytest.LogCaptureFixture, inventory: AntaInventory) -> None:
@@ -160,9 +188,11 @@ async def test_check_runner_log_for_windows_fake(caplog: pytest.LogCaptureFixtur
manager = ResultManager()
# Using dry-run to shorten the test
await main(manager, inventory, FAKE_CATALOG, dry_run=True)
- assert "Running on a non-POSIX system, cannot adjust the maximum number of file descriptors." in caplog.records[-3].message
+ assert "Running on a non-POSIX system, cannot adjust the maximum number of file descriptors." in caplog.records[0].message
+
+# TODO: Remove this in ANTA v2.0.0
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.parametrize(
("inventory", "tags", "tests", "devices_count", "tests_count"),
[
@@ -192,6 +222,7 @@ async def test_prepare_tests(
assert sum(len(tests) for tests in selected_tests.values()) == tests_count
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
async def test_dry_run(caplog: pytest.LogCaptureFixture, inventory: AntaInventory) -> None:
"""Test that when dry_run is True, no tests are run."""
caplog.set_level(logging.INFO)
@@ -200,6 +231,7 @@ async def test_dry_run(caplog: pytest.LogCaptureFixture, inventor
assert "Dry-run mode, exiting before running the tests." in caplog.records[-1].message
+# TODO: Move this to AntaRunner tests in ANTA v2.0.0
async def test_cannot_create_test(caplog: pytest.LogCaptureFixture, inventory: AntaInventory) -> None:
"""Test that when an Exception is raised during test instantiation, it is caught and a log is raised."""
caplog.set_level(logging.CRITICAL)
diff --git a/tests/units/test_settings.py b/tests/units/test_settings.py
new file mode 100644
index 000000000..e8640ca33
--- /dev/null
+++ b/tests/units/test_settings.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2023-2025 Arista Networks, Inc.
+# Use of this source code is governed by the Apache License 2.0
+# that can be found in the LICENSE file.
+"""Unit tests for the anta.settings module."""
+
+from __future__ import annotations
+
+import logging
+import os
+import sys
+from unittest.mock import patch
+
+import pytest
+from pydantic import ValidationError
+
+from anta.settings import DEFAULT_MAX_CONCURRENCY, DEFAULT_NOFILE, DEFAULT_SCHEDULING_STRATEGY, DEFAULT_SCHEDULING_TESTS_PER_DEVICE, AntaRunnerSettings
+
+if os.name == "posix":
+ # The resource module is not available on non-POSIX systems
+ import resource
+
+
+class TestAntaRunnerSettings:
+ """Tests for the AntaRunnerSettings class."""
+
+ def test_defaults(self, setenvvar: pytest.MonkeyPatch) -> None:
+ """Test defaults for ANTA runner settings."""
+ settings = AntaRunnerSettings()
+ assert settings.nofile == DEFAULT_NOFILE
+ assert settings.max_concurrency == DEFAULT_MAX_CONCURRENCY
+ assert settings.scheduling_strategy == DEFAULT_SCHEDULING_STRATEGY
+ assert settings.scheduling_tests_per_device == DEFAULT_SCHEDULING_TESTS_PER_DEVICE
+
+ def test_env_var(self, setenvvar: pytest.MonkeyPatch) -> None:
+ """Test setting ANTA runner settings through environment variables."""
+ setenvvar.setenv("ANTA_NOFILE", "20480")
+ setenvvar.setenv("ANTA_SCHEDULING_STRATEGY", "device-by-device")
+ setenvvar.setenv("ANTA_SCHEDULING_TESTS_PER_DEVICE", "50")
+ settings = AntaRunnerSettings()
+ assert settings.nofile == 20480
+ assert settings.scheduling_strategy == "device-by-device"
+ assert settings.scheduling_tests_per_device == 50
+ assert settings.max_concurrency == DEFAULT_MAX_CONCURRENCY
+
+ def test_validation(self, setenvvar: pytest.MonkeyPatch) -> None:
+ """Test validation of ANTA runner settings."""
+ setenvvar.setenv("ANTA_NOFILE", "-1")
+ with pytest.raises(ValidationError):
+ AntaRunnerSettings()
+
+ setenvvar.setenv("ANTA_MAX_CONCURRENCY", "0")
+ with pytest.raises(ValidationError):
+ AntaRunnerSettings()
+
+ setenvvar.setenv("ANTA_SCHEDULING_TESTS_PER_DEVICE", "unlimited")
+ with pytest.raises(ValidationError):
+ AntaRunnerSettings()
+
+ setenvvar.setenv("ANTA_SCHEDULING_STRATEGY", "unlimited")
+ with pytest.raises(ValidationError):
+ AntaRunnerSettings()
+
+ @pytest.mark.skipif(os.name == "posix", reason="Run this test on 
Windows only") + def test_file_descriptor_limit_windows(self, caplog: pytest.LogCaptureFixture) -> None: + """Test file_descriptor_limit on Windows.""" + caplog.set_level(logging.INFO) + settings = AntaRunnerSettings() + assert settings.file_descriptor_limit == sys.maxsize + assert "Running on a non-POSIX system, cannot adjust the maximum number of file descriptors." in caplog.text + + @pytest.mark.skipif(os.name != "posix", reason="Cannot run this test on Windows") + def test_file_descriptor_limit_posix(self, caplog: pytest.LogCaptureFixture) -> None: + """Test file_descriptor_limit on POSIX systems.""" + with ( + caplog.at_level(logging.DEBUG), + patch.dict("os.environ", {"ANTA_NOFILE": "20480"}), + patch("resource.getrlimit") as getrlimit_mock, + patch("resource.setrlimit") as setrlimit_mock, + ): + # Simulate the default system limits + system_limits = (8192, 1048576) + + # Setup getrlimit mock return value + getrlimit_mock.return_value = system_limits + + # Simulate setrlimit behavior + def side_effect_setrlimit(resource_id: int, limits: tuple[int, int]) -> None: + _ = resource_id + getrlimit_mock.return_value = (limits[0], limits[1]) + + setrlimit_mock.side_effect = side_effect_setrlimit + + settings = AntaRunnerSettings() + + # Assert the limits were updated as expected + assert settings.file_descriptor_limit == 20480 + assert "Initial file descriptor limits: Soft Limit: 8192 | Hard Limit: 1048576" in caplog.text + assert "Setting file descriptor soft limit to 20480" in caplog.text + + setrlimit_mock.assert_called_once_with(resource.RLIMIT_NOFILE, (20480, 1048576)) # pylint: disable=possibly-used-before-assignment diff --git a/tests/units/test_tools.py b/tests/units/test_tools.py index 396b0c840..9995bf84c 100644 --- a/tests/units/test_tools.py +++ b/tests/units/test_tools.py @@ -5,13 +5,20 @@ from __future__ import annotations +import asyncio from contextlib import AbstractContextManager from contextlib import nullcontext as does_not_raise -from typing import Any +from typing import TYPE_CHECKING, Any +from unittest.mock import Mock import pytest -from anta.tools import convert_categories, custom_division, format_data, get_dict_superset, get_failed_logs, get_item, get_value +# Import as Result to avoid PytestCollectionWarning +from anta.result_manager.models import TestResult as Result +from anta.tools import convert_categories, custom_division, format_data, get_dict_superset, get_failed_logs, get_item, get_value, limit_concurrency + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, AsyncIterator, Coroutine, Sequence TEST_GET_FAILED_LOGS_DATA = [ {"id": 1, "name": "Alice", "age": 30, "email": "alice@example.com"}, @@ -527,3 +534,78 @@ def test_convert_categories(test_input: list[str], expected_raise: AbstractConte def test_format_data(input_data: dict[str, bool], expected_output: str) -> None: """Test format_data.""" assert format_data(input_data) == expected_output + + +class TestLimitConcurrency: + """Test limit_concurrency function.""" + + # Helper classes and functions for testing limit_concurrency function + class _EmptyGenerator: + """Helper class to create an empty async generator.""" + + def __aiter__(self) -> AsyncIterator[Coroutine[Any, Any, Any]]: + """Make this class an async iterator.""" + return self + + async def __anext__(self) -> Coroutine[Any, Any, Any]: + """Raise StopAsyncIteration.""" + raise StopAsyncIteration + + async def _mock_test_coro(self, result: Result) -> Result: + """Mock coroutine simulating a test.""" + # Simulate some work + await 
asyncio.sleep(0.1) + return result + + async def _create_test_generator(self, results: Sequence[Result]) -> AsyncGenerator[Coroutine[Any, Any, Result], None]: + """Create a test generator yielding mock test coroutines.""" + for result in results: + yield self._mock_test_coro(result) + + # Unit tests + async def test_limit_concurrency_with_zero_limit(self) -> None: + """Test that limit_concurrency raises RuntimeError when limit is 0.""" + mock_result = Mock(spec=Result) + generator = self._create_test_generator([mock_result]) + + with pytest.raises(RuntimeError, match="Concurrency limit must be greater than 0"): + await limit_concurrency(generator, limit=0).__anext__() # pylint: disable=unnecessary-dunder-call + + async def test_limit_concurrency_with_negative_limit(self) -> None: + """Test that limit_concurrency raises RuntimeError when limit is negative.""" + mock_result = Mock(spec=Result) + generator = self._create_test_generator([mock_result]) + + with pytest.raises(RuntimeError, match="Concurrency limit must be greater than 0"): + await limit_concurrency(generator, limit=-1).__anext__() # pylint: disable=unnecessary-dunder-call + + async def test_limit_concurrency_with_empty_generator(self) -> None: + """Test limit_concurrency behavior with an empty generator.""" + results = [await result async for result in limit_concurrency(self._EmptyGenerator(), limit=10)] + assert len(results) == 0 + + async def test_limit_concurrency_with_concurrent_limit(self) -> None: + """Test limit_concurrency enforces the maximum number of concurrently running tasks.""" + max_concurrent = 0 + current_concurrent = 0 + + async def instrumented_coro(result: int) -> int: + nonlocal current_concurrent, max_concurrent + current_concurrent += 1 + max_concurrent = max(max_concurrent, current_concurrent) + # Simulate work + await asyncio.sleep(0.1) + current_concurrent -= 1 + return result + + async def test_generator() -> AsyncGenerator[Coroutine[Any, Any, int], None]: + for i in range(10): + yield instrumented_coro(i) + + # Run with limit of 3 to test concurrency limit + completed_results = [await task async for task in limit_concurrency(test_generator(), limit=3)] + + # Verify all results were returned + assert len(completed_results) == 10 + # Verify that the maximum number of concurrently running tasks never exceeded the limit + assert max_concurrent <= 3
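+
+    # A minimal illustrative sketch, not part of the original suite. It assumes only the
+    # limit_concurrency behavior exercised above: the function consumes an async generator
+    # of coroutines and yields awaitable tasks as they complete, so completion order may
+    # differ from submission order. The helper names below are hypothetical.
+    async def test_limit_concurrency_returns_all_results(self) -> None:
+        """Illustrative sketch: no coroutine is dropped when the limit is smaller than the workload."""
+
+        async def echo(value: int) -> int:
+            # Simulate variable amounts of work so completion order can differ
+            await asyncio.sleep(0.01 * (value % 3))
+            return value
+
+        async def generator() -> AsyncGenerator[Coroutine[Any, Any, int], None]:
+            for i in range(5):
+                yield echo(i)
+
+        results = [await task async for task in limit_concurrency(generator(), limit=2)]
+
+        # All five results come back, though not necessarily in submission order
+        assert sorted(results) == [0, 1, 2, 3, 4]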