Skip to content

Commit

Permalink
refactor(anta): Move conversion of categories to reporters (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
gmuloc authored Oct 10, 2024
1 parent 93a4b44 commit 42476d9
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 15 deletions.
3 changes: 2 additions & 1 deletion anta/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rich.table import Table

from anta import RICH_COLOR_PALETTE, RICH_COLOR_THEME
from anta.tools import convert_categories

if TYPE_CHECKING:
import pathlib
Expand Down Expand Up @@ -125,7 +126,7 @@ def report_all(self, manager: ResultManager, title: str = "All tests results") -
def add_line(result: TestResult) -> None:
state = self._color_result(result.result)
message = self._split_list_to_txt_list(result.messages) if len(result.messages) > 0 else ""
categories = ", ".join(result.categories)
categories = ", ".join(convert_categories(result.categories))
table.add_row(str(result.name), result.test, state, message, result.description, categories)

for result in manager.results:
Expand Down
3 changes: 2 additions & 1 deletion anta/reporter/csv_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import TYPE_CHECKING

from anta.logger import anta_log_exception
from anta.tools import convert_categories

if TYPE_CHECKING:
import pathlib
Expand Down Expand Up @@ -71,7 +72,7 @@ def convert_to_list(cls, result: TestResult) -> list[str]:
TestResult converted into a list.
"""
message = cls.split_list_to_txt_list(result.messages) if len(result.messages) > 0 else ""
categories = cls.split_list_to_txt_list(result.categories) if len(result.categories) > 0 else "None"
categories = cls.split_list_to_txt_list(convert_categories(result.categories)) if len(result.categories) > 0 else "None"
return [
str(result.name),
result.test,
Expand Down
7 changes: 4 additions & 3 deletions anta/reporter/md_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from anta.constants import MD_REPORT_TOC
from anta.logger import anta_log_exception
from anta.result_manager.models import AntaTestStatus
from anta.tools import convert_categories

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -238,8 +239,8 @@ def generate_rows(self) -> Generator[str, None, None]:
"""Generate the rows of the summary totals device under test table."""
for device, stat in self.results.device_stats.items():
total_tests = stat.tests_success_count + stat.tests_skipped_count + stat.tests_failure_count + stat.tests_error_count
categories_skipped = ", ".join(sorted(stat.categories_skipped))
categories_failed = ", ".join(sorted(stat.categories_failed))
categories_skipped = ", ".join(sorted(convert_categories(list(stat.categories_skipped))))
categories_failed = ", ".join(sorted(convert_categories(list(stat.categories_failed))))
yield (
f"| {device} | {total_tests} | {stat.tests_success_count} | {stat.tests_skipped_count} | {stat.tests_failure_count} | {stat.tests_error_count} "
f"| {categories_skipped or '-'} | {categories_failed or '-'} |\n"
Expand Down Expand Up @@ -286,7 +287,7 @@ def generate_rows(self) -> Generator[str, None, None]:
"""Generate the rows of the all test results table."""
for result in self.results.get_results(sort_by=["name", "test"]):
messages = self.safe_markdown(", ".join(result.messages))
categories = ", ".join(result.categories)
categories = ", ".join(convert_categories(result.categories))
yield (
f"| {result.name or '-'} | {categories or '-'} | {result.test or '-'} "
f"| {result.description or '-'} | {self.safe_markdown(result.custom_field) or '-'} | {result.result or '-'} | {messages or '-'} |\n"
Expand Down
4 changes: 0 additions & 4 deletions anta/result_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from functools import cached_property
from itertools import chain

from anta.constants import ACRONYM_CATEGORIES
from anta.result_manager.models import AntaTestStatus, TestResult

from .models import CategoryStats, DeviceStats, TestStats
Expand Down Expand Up @@ -162,9 +161,6 @@ def _update_stats(self, result: TestResult) -> None:
result
TestResult to update the statistics.
"""
result.categories = [
" ".join(word.upper() if word.lower() in ACRONYM_CATEGORIES else word.title() for word in category.split()) for category in result.categories
]
count_attr = f"tests_{result.result}_count"

# Update device stats
Expand Down
23 changes: 23 additions & 0 deletions anta/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

from anta.constants import ACRONYM_CATEGORIES
from anta.custom_types import REGEXP_PATH_MARKERS
from anta.logger import format_td

Expand Down Expand Up @@ -372,3 +373,25 @@ def safe_command(command: str) -> str:
The sanitized command.
"""
return re.sub(rf"{REGEXP_PATH_MARKERS}", "_", command)


def convert_categories(categories: list[str]) -> list[str]:
"""Convert categories for reports.
if the category is part of the defined acronym, transform it to upper case
otherwise capitalize the first letter.
Parameters
----------
categories
A list of categories
Returns
-------
list[str]
The list of converted categories
"""
if isinstance(categories, list):
return [" ".join(word.upper() if word.lower() in ACRONYM_CATEGORIES else word.title() for word in category.split()) for category in categories]
msg = f"Wrong input type '{type(categories)}' for convert_categories."
raise TypeError(msg)
3 changes: 2 additions & 1 deletion tests/units/reporter/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from anta.reporter.csv_reporter import ReportCsv
from anta.result_manager import ResultManager
from anta.tools import convert_categories


class TestReportCsv:
Expand All @@ -25,7 +26,7 @@ def compare_csv_and_result(self, rows: list[Any], index: int, result_manager: Re
assert rows[index + 1][2] == result_manager.results[index].result
assert rows[index + 1][3] == ReportCsv().split_list_to_txt_list(result_manager.results[index].messages)
assert rows[index + 1][4] == result_manager.results[index].description
assert rows[index + 1][5] == ReportCsv().split_list_to_txt_list(result_manager.results[index].categories)
assert rows[index + 1][5] == ReportCsv().split_list_to_txt_list(convert_categories(result_manager.results[index].categories))

def test_report_csv_generate(
self,
Expand Down
8 changes: 4 additions & 4 deletions tests/units/result_manager/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ def test_sorted_category_stats(self, list_result_factory: Callable[[int], list[T

result_manager.results = results

# Check the current categories order and name format
expected_order = ["OSPF", "BGP", "VXLAN", "System"]
# Check the current categories order
expected_order = ["ospf", "bgp", "vxlan", "system"]
assert list(result_manager.category_stats.keys()) == expected_order

# Check the sorted categories order and name format
expected_order = ["BGP", "OSPF", "System", "VXLAN"]
# Check the sorted categories order
expected_order = ["bgp", "ospf", "system", "vxlan"]
assert list(result_manager.sorted_category_stats.keys()) == expected_order

@pytest.mark.parametrize(
Expand Down
16 changes: 15 additions & 1 deletion tests/units/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pytest

from anta.tools import custom_division, get_dict_superset, get_failed_logs, get_item, get_value
from anta.tools import convert_categories, custom_division, get_dict_superset, get_failed_logs, get_item, get_value

TEST_GET_FAILED_LOGS_DATA = [
{"id": 1, "name": "Alice", "age": 30, "email": "alice@example.com"},
Expand Down Expand Up @@ -499,3 +499,17 @@ def test_get_item(
def test_custom_division(numerator: float, denominator: float, expected_result: str) -> None:
"""Test custom_division."""
assert custom_division(numerator, denominator) == expected_result


@pytest.mark.parametrize(
("test_input", "expected_raise", "expected_result"),
[
pytest.param([], does_not_raise(), [], id="empty list"),
pytest.param(["bgp", "system", "vlan", "configuration"], does_not_raise(), ["BGP", "System", "VLAN", "Configuration"], id="list with acronyms and titles"),
pytest.param(42, pytest.raises(TypeError, match="Wrong input type"), None, id="wrong input type"),
],
)
def test_convert_categories(test_input: list[str], expected_raise: AbstractContextManager[Exception], expected_result: list[str]) -> None:
"""Test convert_categories."""
with expected_raise:
assert convert_categories(test_input) == expected_result

0 comments on commit 42476d9

Please sign in to comment.