diff --git a/anta/device.py b/anta/device.py index 3624fdb2e..587a799a1 100644 --- a/anta/device.py +++ b/anta/device.py @@ -143,6 +143,10 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo self.cache: AntaCache | None = None # Keeping cache_locks for backward compatibility. self.cache_locks: defaultdict[str, asyncio.Lock] | None = None + self.command_queue: asyncio.Queue[AntaCommand] = asyncio.Queue() + self.batch_task: asyncio.Task[None] | None = None + # TODO: Check if we want to make the batch size configurable + self.batch_size: int = 100 # Initialize cache if not disabled if not disable_cache: @@ -166,6 +170,12 @@ def _init_cache(self) -> None: self.cache = AntaCache(device=self.name, ttl=60) self.cache_locks = self.cache.locks + def init_batch_task(self) -> None: + """Initialize the batch task for the device.""" + if self.batch_task is None: + logger.debug("<%s>: Starting the batch task", self.name) + self.batch_task = asyncio.create_task(self._batch_task()) + @property def cache_statistics(self) -> dict[str, Any] | None: """Return the device cache statistics for logging purposes.""" @@ -198,6 +208,72 @@ def __repr__(self) -> str: f"disable_cache={self.cache is None!r})" ) + async def _batch_task(self) -> None: + """Background task to retrieve commands put by tests from the command queue of this device. + + Test coroutines put their AntaCommand instances in the queue, this task retrieves them. Once they stop coming, + the instances are grouped by UID, split into JSON and text batches, and collected in batches of `batch_size`. + """ + collection_tasks: list[asyncio.Task[None]] = [] + all_commands: list[AntaCommand] = [] + + while True: + try: + get_await = self.command_queue.get() + command = await asyncio.wait_for(get_await, timeout=0.5) + logger.debug("<%s>: Command retrieved from the queue: %s", self.name, command) + all_commands.append(command) + except asyncio.TimeoutError: # noqa: PERF203 + logger.debug("<%s>: All test commands have been retrieved from the queue", self.name) + break + + # Group all command instances by UID + command_groups: defaultdict[str, list[AntaCommand]] = defaultdict(list[AntaCommand]) + for command in all_commands: + command_groups[command.uid].append(command) + + # Split into JSON and text batches. We can safely take the first command instance from each UID as they are the same. + json_commands = {uid: commands for uid, commands in command_groups.items() if commands[0].ofmt == "json"} + text_commands = {uid: commands for uid, commands in command_groups.items() if commands[0].ofmt == "text"} + + # Process JSON batches + for i in range(0, len(json_commands), self.batch_size): + batch = dict(list(json_commands.items())[i : i + self.batch_size]) + task = asyncio.create_task(self._collect_batch(batch, ofmt="json")) + collection_tasks.append(task) + + # Process text batches + for i in range(0, len(text_commands), self.batch_size): + batch = dict(list(text_commands.items())[i : i + self.batch_size]) + task = asyncio.create_task(self._collect_batch(batch, ofmt="text")) + collection_tasks.append(task) + + # Wait for all collection tasks to complete + if collection_tasks: + logger.debug("<%s>: Waiting for %d collection tasks to complete", self.name, len(collection_tasks)) + await asyncio.gather(*collection_tasks) + + # TODO: Handle other exceptions + + logger.debug("<%s>: Stopping the batch task", self.name) + + async def _collect_batch(self, command_groups: dict[str, list[AntaCommand]], ofmt: Literal["json", "text"] = "json") -> None: + """Collect a batch of device commands. + + This coroutine must be implemented by subclasses that want to support command queuing + in conjunction with the `_batch_task()` method. + + Parameters + ---------- + command_groups + Mapping of command instances grouped by UID to avoid duplicate commands. + ofmt + The output format of the batch. + """ + _ = (command_groups, ofmt) + msg = f"_collect_batch method has not been implemented in {self.__class__.__name__} definition" + raise NotImplementedError(msg) + @abstractmethod async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: """Collect device command output. @@ -251,16 +327,38 @@ async def collect(self, command: AntaCommand, *, collection_id: str | None = Non else: await self._collect(command=command, collection_id=collection_id) - async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None: + async def collect_commands(self, commands: list[AntaCommand], *, command_queuing: bool = False, collection_id: str | None = None) -> None: """Collect multiple commands. Parameters ---------- commands The commands to collect. + command_queuing + If True, the commands are put in a queue and collected in batches. Default is False. collection_id - An identifier used to build the eAPI request ID. + An identifier used to build the eAPI request ID. Not used when command queuing is enabled. """ + # Collect the commands with queuing + if command_queuing: + # Disable cache for this device as it is not needed when using command queuing + self.cache = None + self.cache_locks = None + + # Initialize the device batch task if not already running + self.init_batch_task() + + # Put the commands in the queue + for command in commands: + logger.debug("<%s>: Putting command in the queue: %s", self.name, command) + await self.command_queue.put(command) + + # Wait for all commands to be collected. + logger.debug("<%s>: Waiting for all commands to be collected", self.name) + await asyncio.gather(*[command.event.wait() for command in commands]) + return + + # Collect the commands without queuing. Default behavior. await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands)) @abstractmethod @@ -431,6 +529,78 @@ def _keys(self) -> tuple[Any, ...]: """ return (self._session.host, self._session.port) + async def _collect_batch(self, command_groups: dict[str, list[AntaCommand]], ofmt: Literal["json", "text"] = "json") -> None: # noqa: C901 + """Collect a batch of device commands. + + Parameters + ---------- + command_groups + Mapping of command instances grouped by UID to avoid duplicate commands. + ofmt + The output format of the batch. + """ + # Add 'enable' command if required + cmds = [] + if self.enable and self._enable_password is not None: + cmds.append({"cmd": "enable", "input": str(self._enable_password)}) + elif self.enable: + # No password + cmds.append({"cmd": "enable"}) + + # Take first instance from each group for the actual commands + cmds.extend( + [ + {"cmd": instances[0].command, "revision": instances[0].revision} if instances[0].revision else {"cmd": instances[0].command} + for instances in command_groups.values() + ] + ) + + try: + response = await self._session.cli( + commands=cmds, + ofmt=ofmt, + # TODO: See if we want to have different batches for different versions + version=1, + # TODO: See if want to have a different req_id for each batch + req_id=f"ANTA-{id(command_groups)}", + ) + + # Do not keep response of 'enable' command + if self.enable: + response = response[1:] + + # Update all AntaCommand instances with their output and signal their completion + logger.debug("<%s>: Collected batch of commands, signaling their completion", self.name) + for idx, instances in enumerate(command_groups.values()): + output = response[idx] + for cmd_instance in instances: + cmd_instance.output = output + cmd_instance.event.set() + + except asynceapi.EapiCommandError as e: + # TODO: Handle commands that passed + for instances in command_groups.values(): + for cmd_instance in instances: + cmd_instance.errors = e.errors + if cmd_instance.requires_privileges: + logger.error( + "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.", + cmd_instance.command, + self.name, + ) + if cmd_instance.supported: + logger.error("Command '%s' failed on %s: %s", cmd_instance.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors) + else: + logger.debug("Command '%s' is not supported on '%s' (%s)", cmd_instance.command, self.name, self.hw_model) + cmd_instance.event.set() + + # TODO: Handle other exceptions + except Exception as e: + for instances in command_groups.values(): + for cmd_instance in instances: + cmd_instance.errors = [exc_to_str(e)] + cmd_instance.event.set() + async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: """Collect device command output from EOS using aio-eapi. diff --git a/anta/models.py b/anta/models.py index 1eb677849..542ca4569 100644 --- a/anta/models.py +++ b/anta/models.py @@ -5,11 +5,12 @@ from __future__ import annotations +import asyncio import hashlib import logging import re from abc import ABC, abstractmethod -from functools import wraps +from functools import cached_property, wraps from string import Formatter from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar @@ -165,7 +166,9 @@ class AntaCommand(BaseModel): Pydantic Model containing the variables values used to render the template. use_cache Enable or disable caching for this AntaCommand if the AntaDevice supports it. - + event + Event to signal that the command has been collected. Used by an AntaDevice to signal an AntaTest that the command has been collected. + Only relevant when an AntaTest runs with `command_queuing=True`. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -179,13 +182,13 @@ class AntaCommand(BaseModel): errors: list[str] = [] params: AntaParamsBaseModel = AntaParamsBaseModel() use_cache: bool = True + event: asyncio.Event | None = None - @property + @cached_property def uid(self) -> str: """Generate a unique identifier for this command.""" uid_str = f"{self.command}_{self.version}_{self.revision or 'NA'}_{self.ofmt}" - # Ignoring S324 probable use of insecure hash function - sha1 is enough for our needs. - return hashlib.sha1(uid_str.encode()).hexdigest() # noqa: S324 + return hashlib.sha256(uid_str.encode()).hexdigest() @property def json_output(self) -> dict[str, Any]: @@ -431,6 +434,8 @@ def __init__( device: AntaDevice, inputs: dict[str, Any] | AntaTest.Input | None = None, eos_data: list[dict[Any, Any] | str] | None = None, + *, + command_queuing: bool = False, ) -> None: """Initialize an AntaTest instance. @@ -443,10 +448,14 @@ def __init__( eos_data Populate outputs of the test commands instead of collecting from devices. This list must have the same length and order than the `instance_commands` instance attribute. + command_queuing + If True, the commands of this test will be queued in the device command queue and be sent in batches. + Default is False, which means the commands will be sent one by one to the device. """ self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}") self.device: AntaDevice = device self.inputs: AntaTest.Input + self.command_queuing = command_queuing self.instance_commands: list[AntaCommand] = [] self.result: TestResult = TestResult( name=device.name, @@ -496,10 +505,17 @@ def _init_commands(self, eos_data: list[dict[Any, Any] | str] | None) -> None: if self.__class__.commands: for cmd in self.__class__.commands: if isinstance(cmd, AntaCommand): - self.instance_commands.append(cmd.model_copy()) + command = cmd.model_copy() + if self.command_queuing: + command.event = asyncio.Event() + self.instance_commands.append(command) elif isinstance(cmd, AntaTemplate): try: - self.instance_commands.extend(self.render(cmd)) + rendered_commands = self.render(cmd) + if self.command_queuing: + for command in rendered_commands: + command.event = asyncio.Event() + self.instance_commands.extend(rendered_commands) except AntaTemplateRenderError as e: self.result.is_error(message=f"Cannot render template {{{e.template}}}") return @@ -590,7 +606,7 @@ async def collect(self) -> None: """Collect outputs of all commands of this test class from the device of this test instance.""" try: if self.blocked is False: - await self.device.collect_commands(self.instance_commands, collection_id=self.name) + await self.device.collect_commands(self.instance_commands, collection_id=self.name, command_queuing=self.command_queuing) except Exception as e: # noqa: BLE001 # device._collect() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -615,7 +631,6 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]: async def wrapper( self: AntaTest, eos_data: list[dict[Any, Any] | str] | None = None, - **kwargs: dict[str, Any], ) -> TestResult: """Inner function for the anta_test decorator. @@ -657,7 +672,7 @@ async def wrapper( return self.result try: - function(self, **kwargs) + function(self) except Exception as e: # noqa: BLE001 # test() is user-defined code. # We need to catch everything if we want the AntaTest object @@ -699,7 +714,7 @@ def update_progress(cls: type[AntaTest]) -> None: cls.progress.update(cls.nrfu_task, advance=1) @abstractmethod - def test(self) -> Coroutine[Any, Any, TestResult]: + def test(self) -> None: """Core of the test logic. This is an abstractmethod that must be implemented by child classes. diff --git a/anta/runner.py b/anta/runner.py index 84e27a133..eec277f88 100644 --- a/anta/runner.py +++ b/anta/runner.py @@ -26,6 +26,9 @@ from anta.result_manager import ResultManager from anta.result_manager.models import TestResult +logger = logging.getLogger(__name__) +COMMAND_QUEUING = False + if os.name == "posix": import resource @@ -57,7 +60,14 @@ def adjust_rlimit_nofile() -> tuple[int, int]: return resource.getrlimit(resource.RLIMIT_NOFILE) -logger = logging.getLogger(__name__) +def get_command_queuing() -> bool: + """Return the command queuing flag from the environment variable if set.""" + try: + command_queuing = bool(os.environ.get("ANTA_COMMAND_QUEUING", COMMAND_QUEUING)) + except ValueError as exception: + logger.warning("The ANTA_COMMAND_QUEUING environment variable value is invalid: %s\nDefault to %s.", exc_to_str(exception), COMMAND_QUEUING) + command_queuing = COMMAND_QUEUING + return command_queuing def log_cache_statistics(devices: list[AntaDevice]) -> None: @@ -193,11 +203,12 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio list[Coroutine[Any, Any, TestResult]] The list of coroutines to run. """ + command_queuing = get_command_queuing() coros = [] for device, test_definitions in selected_tests.items(): for test in test_definitions: try: - test_instance = test.test(device=device, inputs=test.inputs) + test_instance = test.test(device=device, inputs=test.inputs, command_queuing=command_queuing) if manager is not None: manager.add(test_instance.result) coros.append(test_instance.test()) @@ -215,7 +226,7 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio @cprofile() -async def main( +async def main( # noqa: C901 manager: ResultManager, inventory: AntaInventory, catalog: AntaCatalog, @@ -308,4 +319,5 @@ async def main( for result in results: manager.add(result) - log_cache_statistics(selected_inventory.devices) + if not get_command_queuing(): + log_cache_statistics(selected_inventory.devices)