Skip to content

Commit

Permalink
refactor: Extract regular expression to support regexp unit tests (#686)
Browse files Browse the repository at this point in the history
* refactor: Extract regular expression to support regexp unit tests

* chore: Implement security fixes (#683)

* chore: Update Dockerfile to use non-root user

* fix(anta): Update regexp syntax for better readability

* Update Dockerfile

* Update Dockerfile

* Update Dockerfile

* Update Dockerfile

* fix(anta): Update regexp syntax for better readability

---------

Co-authored-by: Matthieu Tâche <mtache@arista.com>

* refactor: pytest for REGEXP_EOS_BLACKLIST_CMDS

---------

Co-authored-by: Matthieu Tâche <mtache@arista.com>
  • Loading branch information
titom73 and mtache authored May 17, 2024
1 parent 1c04244 commit cf595fd
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 17 deletions.
3 changes: 2 additions & 1 deletion anta/cli/exec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from click.exceptions import UsageError
from httpx import ConnectError, HTTPError

from anta.custom_types import REGEXP_PATH_MARKERS
from anta.device import AntaDevice, AsyncEOSDevice
from anta.models import AntaCommand

Expand Down Expand Up @@ -60,7 +61,7 @@ async def collect_commands(
async def collect(dev: AntaDevice, command: str, outformat: Literal["json", "text"]) -> None:
outdir = Path() / root_dir / dev.name / outformat
outdir.mkdir(parents=True, exist_ok=True)
safe_command = re.sub(r"[\\\/\s]", "_", command)
safe_command = re.sub(rf"{REGEXP_PATH_MARKERS}", "_", command)
c = AntaCommand(command=command, ofmt=outformat)
await dev.collect(c)
if not c.collected:
Expand Down
43 changes: 34 additions & 9 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
from pydantic import Field
from pydantic.functional_validators import AfterValidator, BeforeValidator

# Regular Expression definition
# TODO: make this configurable - with an env var maybe?
# NOTE(review): the last pattern needs at least one word character after "wr",
# so a bare "wr" is not matched - confirm whether EOS accepts it as "write".
REGEXP_EOS_BLACKLIST_CMDS = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"]
"""List of regular expressions to blacklist from eos commands."""
REGEXP_PATH_MARKERS = r"[\\\/\s]"
"""Match a single path separator (slash or backslash) or whitespace character."""
REGEXP_INTERFACE_ID = r"\d+(\/\d+)*(\.\d+)?"
"""Match Interface ID like 1, 1/1 or 1/1.1."""
REGEXP_TYPE_EOS_INTERFACE = r"^(Dps|Ethernet|Fabric|Loopback|Management|Port-Channel|Tunnel|Vlan|Vxlan)[0-9]+(\/[0-9]+)*(\.[0-9]+)?$"
"""Match EOS interface types like Ethernet1/1, Vlan1, Loopback1, etc."""
# As written, the numeric part accepts 0 to 8199 without leading zeros:
# `8[01][0-9]{2}` already covers 8000-8199, which includes what `819[01]` matches.
REGEXP_TYPE_VXLAN_SRC_INTERFACE = r"^(Loopback)([0-9]|[1-9][0-9]{1,2}|[1-7][0-9]{3}|8[01][0-9]{2}|819[01])$"
"""Match Vxlan source interface like Loopback10."""
REGEXP_TYPE_HOSTNAME = r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$"
"""Match hostname like `my-hostname`, `my-hostname-1`, `my-hostname-1-2`."""

# Regexp BGP AFI/SAFI
REGEXP_BGP_L2VPN_AFI = r"\b(l2[\s\-]?vpn[\s\-]?evpn)\b"
"""Match L2VPN EVPN AFI."""
REGEXP_BGP_IPV4_MPLS_LABELS = r"\b(ipv4[\s\-]?mpls[\s\-]?label(s)?)\b"
"""Match IPv4 MPLS Labels."""
REGEXP_BGP_IPV4_MPLS_VPN = r"\b(ipv4[\s\-]?mpls[\s\-]?vpn)\b"
"""Match IPv4 MPLS VPN."""
REGEXP_BGP_IPV4_UNICAST = r"\b(ipv4[\s\-]?uni[\s\-]?cast)\b"
"""Match IPv4 Unicast."""

# Backward-compatible aliases: these two constants were first published with a
# `REGEX_` prefix while every other constant in this module uses `REGEXP_`.
REGEX_BGP_IPV4_MPLS_VPN = REGEXP_BGP_IPV4_MPLS_VPN
"""Alias of REGEXP_BGP_IPV4_MPLS_VPN kept for existing imports."""
REGEX_BGP_IPV4_UNICAST = REGEXP_BGP_IPV4_UNICAST
"""Alias of REGEXP_BGP_IPV4_UNICAST kept for existing imports."""


def aaa_group_prefix(v: str) -> str:
"""Prefix the AAA method with 'group' if it is known."""
Expand All @@ -24,7 +49,7 @@ def interface_autocomplete(v: str) -> str:
- `po` will be changed to `Port-Channel`
- `lo` will be changed to `Loopback`
"""
intf_id_re = re.compile(r"\d+(\/\d+)*(\.\d+)?")
intf_id_re = re.compile(REGEXP_INTERFACE_ID)
m = intf_id_re.search(v)
if m is None:
msg = f"Could not parse interface ID in interface '{v}'"
Expand All @@ -46,7 +71,7 @@ def interface_case_sensitivity(v: str) -> str:
- loopback -> Loopback
"""
if isinstance(v, str) and len(v) > 0 and not v[0].isupper():
if isinstance(v, str) and v != "" and not v[0].isupper():
return f"{v[0].upper()}{v[1:]}"
return v

Expand All @@ -63,10 +88,10 @@ def bgp_multiprotocol_capabilities_abbreviations(value: str) -> str:
"""
patterns = {
r"\b(l2[\s\-]?vpn[\s\-]?evpn)\b": "l2VpnEvpn",
r"\bipv4[\s_-]?mpls[\s_-]?label(s)?\b": "ipv4MplsLabels",
r"\bipv4[\s_-]?mpls[\s_-]?vpn\b": "ipv4MplsVpn",
r"\bipv4[\s_-]?uni[\s_-]?cast\b": "ipv4Unicast",
REGEXP_BGP_L2VPN_AFI: "l2VpnEvpn",
REGEXP_BGP_IPV4_MPLS_LABELS: "ipv4MplsLabels",
REGEX_BGP_IPV4_MPLS_VPN: "ipv4MplsVpn",
REGEX_BGP_IPV4_UNICAST: "ipv4Unicast",
}

for pattern, replacement in patterns.items():
Expand Down Expand Up @@ -97,7 +122,7 @@ def validate_regex(value: str) -> str:
Vni = Annotated[int, Field(ge=1, le=16777215)]
Interface = Annotated[
str,
Field(pattern=r"^(Dps|Ethernet|Fabric|Loopback|Management|Port-Channel|Tunnel|Vlan|Vxlan)[0-9]+(\/[0-9]+)*(\.[0-9]+)?$"),
Field(pattern=REGEXP_TYPE_EOS_INTERFACE),
BeforeValidator(interface_autocomplete),
BeforeValidator(interface_case_sensitivity),
]
Expand All @@ -109,7 +134,7 @@ def validate_regex(value: str) -> str:
]
VxlanSrcIntf = Annotated[
str,
Field(pattern=r"^(Loopback)([0-9]|[1-9][0-9]{1,2}|[1-7][0-9]{3}|8[01][0-9]{2}|819[01])$"),
Field(pattern=REGEXP_TYPE_VXLAN_SRC_INTERFACE),
BeforeValidator(interface_autocomplete),
BeforeValidator(interface_case_sensitivity),
]
Expand Down Expand Up @@ -139,6 +164,6 @@ def validate_regex(value: str) -> str:
Percent = Annotated[float, Field(ge=0.0, le=100.0)]
PositiveInteger = Annotated[int, Field(ge=0)]
Revision = Annotated[int, Field(ge=1, le=99)]
Hostname = Annotated[str, Field(pattern=r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$")]
Hostname = Annotated[str, Field(pattern=REGEXP_TYPE_HOSTNAME)]
Port = Annotated[int, Field(ge=1, le=65535)]
RegexString = Annotated[str, AfterValidator(validate_regex)]
9 changes: 3 additions & 6 deletions anta/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import BaseModel, ConfigDict, ValidationError, create_model

from anta import GITHUB_SUGGESTION
from anta.custom_types import Revision
from anta.custom_types import REGEXP_EOS_BLACKLIST_CMDS, Revision
from anta.logger import anta_log_exception, exc_to_str
from anta.result_manager.models import TestResult

Expand All @@ -32,9 +32,6 @@
# This would imply overhead to define classes
# https://stackoverflow.com/questions/74103528/type-hinting-an-instance-of-a-nested-class

# TODO: make this configurable - with an env var maybe?
BLACKLIST_REGEX = [r"^reload.*", r"^conf\w*\s*(terminal|session)*", r"^wr\w*\s*\w+"]

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -515,12 +512,12 @@ def blocked(self) -> bool:
"""Check if CLI commands contain a blocked keyword."""
state = False
for command in self.instance_commands:
for pattern in BLACKLIST_REGEX:
for pattern in REGEXP_EOS_BLACKLIST_CMDS:
if re.match(pattern, command.command):
self.logger.error(
"Command <%s> is blocked for security reason matching %s",
command.command,
BLACKLIST_REGEX,
REGEXP_EOS_BLACKLIST_CMDS,
)
self.result.is_error(f"<{command.command}> is blocked for security reason")
state = True
Expand Down
230 changes: 229 additions & 1 deletion tests/units/test_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,203 @@

from __future__ import annotations

import re

import pytest

from anta.custom_types import bgp_multiprotocol_capabilities_abbreviations, interface_autocomplete
from anta.custom_types import (
REGEX_BGP_IPV4_MPLS_VPN,
REGEX_BGP_IPV4_UNICAST,
REGEXP_BGP_IPV4_MPLS_LABELS,
REGEXP_BGP_L2VPN_AFI,
REGEXP_EOS_BLACKLIST_CMDS,
REGEXP_INTERFACE_ID,
REGEXP_PATH_MARKERS,
REGEXP_TYPE_EOS_INTERFACE,
REGEXP_TYPE_HOSTNAME,
REGEXP_TYPE_VXLAN_SRC_INTERFACE,
aaa_group_prefix,
bgp_multiprotocol_capabilities_abbreviations,
interface_autocomplete,
interface_case_sensitivity,
)

# ------------------------------------------------------------------------------
# TEST custom_types.py regular expressions
# ------------------------------------------------------------------------------


def test_regexp_path_markers() -> None:
    """Check which strings REGEXP_PATH_MARKERS finds a marker in."""
    marker_re = re.compile(REGEXP_PATH_MARKERS)

    # Each of these holds a slash, a backslash or a space
    for text in ("show/bgp/interfaces", "show\\bgp", "show bgp"):
        assert marker_re.search(text) is not None

    # None of these holds a path separator or whitespace
    for text in ("aaaa", "11111", ".[]?<>"):
        assert marker_re.search(text) is None


def test_regexp_bgp_l2vpn_afi() -> None:
    """Check REGEXP_BGP_L2VPN_AFI against accepted and rejected spellings."""
    accepted = [
        "l2vpn-evpn",
        "l2 vpn evpn",
        "l2-vpn evpn",
        "l2vpn evpn",
        "l2vpnevpn",
        "l2 vpnevpn",
    ]
    # Extra characters glued to either end break the word boundary
    rejected = ["al2vpn evpn", "l2vpn-evpna"]

    for candidate in accepted:
        assert re.search(REGEXP_BGP_L2VPN_AFI, candidate) is not None
    for candidate in rejected:
        assert re.search(REGEXP_BGP_L2VPN_AFI, candidate) is None


def test_regexp_bgp_ipv4_mpls_labels() -> None:
    """Check REGEXP_BGP_IPV4_MPLS_LABELS against two accepted spellings and one rejected."""
    labels_re = re.compile(REGEXP_BGP_IPV4_MPLS_LABELS)

    for candidate in ("ipv4-mpls-label", "ipv4 mpls labels"):
        assert labels_re.search(candidate) is not None

    # The search is case sensitive, so the capital "M" prevents a match
    assert labels_re.search("ipv4Mplslabel") is None


def test_regex_bgp_ipv4_mpls_vpn() -> None:
    """Check REGEX_BGP_IPV4_MPLS_VPN against one accepted and one rejected spelling."""
    vpn_re = re.compile(REGEX_BGP_IPV4_MPLS_VPN)

    hyphenated = vpn_re.search("ipv4-mpls-vpn")
    underscored = vpn_re.search("ipv4_mplsvpn")

    assert hyphenated is not None
    # An underscore is not one of the accepted separators
    assert underscored is None


def test_regex_bgp_ipv4_unicast() -> None:
    """Check REGEX_BGP_IPV4_UNICAST against one accepted and one rejected spelling."""
    unicast_re = re.compile(REGEX_BGP_IPV4_UNICAST)

    hyphenated = unicast_re.search("ipv4-uni-cast")
    plus_separated = unicast_re.search("ipv4+unicast")

    assert hyphenated is not None
    # A plus sign is not one of the accepted separators
    assert plus_separated is None


def test_regexp_type_interface_id() -> None:
    """Check that REGEXP_INTERFACE_ID finds an ID in each supported numbering form."""
    id_pattern = re.compile(REGEXP_INTERFACE_ID)

    # Plain number, slot/port, sub-interface, and slot/port with sub-interface
    samples = ("123", "123/456", "123.456", "123/456.789")
    for sample in samples:
        assert id_pattern.search(sample) is not None


def test_regexp_type_eos_interface() -> None:
    """Check REGEXP_TYPE_EOS_INTERFACE against valid and invalid interface names."""
    valid_names = [
        "Ethernet0",
        "Vlan100",
        "Port-Channel1/0",
        "Loopback0.1",
        "Management0/0/0",
        "Tunnel1",
        "Vxlan1",
        "Fabric1",
        "Dps1",
    ]
    # Interface type with no usable numeric ID
    missing_id = [
        "Ethernet",
        "Vlan",
        "Port-Channel",
        "Loopback.",
        "Management/",
        "Tunnel",
        "Vxlan",
        "Fabric",
        "Dps",
    ]
    # Interface type followed by a malformed ID
    malformed_id = [
        "Ethernet1/a",
        "Port-Channel-100",
        "Loopback.10",
        "Management/10",
    ]

    for name in valid_names:
        assert re.match(REGEXP_TYPE_EOS_INTERFACE, name) is not None
    for name in missing_id + malformed_id:
        assert re.match(REGEXP_TYPE_EOS_INTERFACE, name) is None


def test_regexp_type_vxlan_src_interface() -> None:
    """Check REGEXP_TYPE_VXLAN_SRC_INTERFACE against valid and invalid Loopback names."""
    valid_names = (
        "Loopback0",
        "Loopback1",
        "Loopback99",
        "Loopback100",
        "Loopback8190",
        "Loopback8199",
    )
    # No ID at all, or an ID the pattern does not accept
    invalid_names = ("Loopback", "Loopback9001", "Loopback9000")

    for name in valid_names:
        assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, name) is not None
    for name in invalid_names:
        assert re.match(REGEXP_TYPE_VXLAN_SRC_INTERFACE, name) is None


def test_regexp_type_hostname() -> None:
    """Check REGEXP_TYPE_HOSTNAME against valid and invalid hostnames."""
    valid_hosts = (
        "hostname",
        "hostname.com",
        "host-name.com",
        "host.name.com",
        "host-name1.com",
    )
    # Leading hyphen or dot, trailing hyphen in a label, and an empty label
    invalid_hosts = (
        "-hostname.com",
        ".hostname.com",
        "hostname-.com",
        "hostname..com",
    )

    for host in valid_hosts:
        assert re.match(REGEXP_TYPE_HOSTNAME, host) is not None
    for host in invalid_hosts:
        assert re.match(REGEXP_TYPE_HOSTNAME, host) is None


@pytest.mark.parametrize(
    ("test_string", "expected"),
    [
        # Blocked by "^reload.*"
        ("reload", True),
        ("reload now", True),
        # Blocked by "^conf\w*\s*(terminal|session)*"
        ("configure terminal", True),
        ("conf t", True),
        # Blocked by "^wr\w*\s*\w+"
        ("write memory", True),
        ("wr mem", True),
        # Not matched by any blacklist expression
        ("show running-config", False),
        ("no shutdown", False),
        ("", False),
    ],
)
def test_regexp_eos_blacklist_cmds(test_string: str, expected: bool) -> None:
    """Check whether any REGEXP_EOS_BLACKLIST_CMDS pattern matches the start of the command."""
    is_blocked = any(re.match(regex, test_string) for regex in REGEXP_EOS_BLACKLIST_CMDS)
    assert is_blocked == expected


# ------------------------------------------------------------------------------
# TEST custom_types.py functions
# ------------------------------------------------------------------------------


def test_interface_autocomplete_success() -> None:
    """Check that interface_autocomplete expands known short aliases."""
    expansions = {
        "et1": "Ethernet1",
        "et1/1": "Ethernet1/1",
        "et1.1": "Ethernet1.1",
        "et1/1.1": "Ethernet1/1.1",
        "eth2": "Ethernet2",
        "po3": "Port-Channel3",
        "lo4": "Loopback4",
    }
    for short_name, full_name in expansions.items():
        assert interface_autocomplete(short_name) == full_name


def test_interface_autocomplete_no_alias() -> None:
    """Check that interface_autocomplete returns names without an alias unchanged."""
    for full_name in ("GigabitEthernet1", "Vlan10", "Tunnel100"):
        assert interface_autocomplete(full_name) == full_name


def test_interface_autocomplete_failure() -> None:
Expand All @@ -34,3 +228,37 @@ def test_interface_autocomplete_failure() -> None:
def test_bgp_multiprotocol_capabilities_abbreviationsh(str_input: str, expected_output: str) -> None:
"""Test bgp_multiprotocol_capabilities_abbreviations."""
assert bgp_multiprotocol_capabilities_abbreviations(str_input) == expected_output


def test_aaa_group_prefix_known_method() -> None:
    """Check that aaa_group_prefix returns known methods unchanged."""
    for method in ("local", "none", "logging"):
        assert aaa_group_prefix(method) == method


def test_aaa_group_prefix_unknown_method() -> None:
    """Check that aaa_group_prefix prepends 'group' to unknown methods."""
    expected_by_method = {
        "demo": "group demo",
        "group1": "group group1",
    }
    for method, prefixed in expected_by_method.items():
        assert aaa_group_prefix(method) == prefixed


def test_interface_case_sensitivity_lowercase() -> None:
    """Check that interface_case_sensitivity capitalizes the first letter of lowercase names."""
    capitalized = {
        "ethernet": "Ethernet",
        "vlan": "Vlan",
        "loopback": "Loopback",
    }
    for lowercase_name, expected in capitalized.items():
        assert interface_case_sensitivity(lowercase_name) == expected


def test_interface_case_sensitivity_mixed_case() -> None:
    """Check that interface_case_sensitivity leaves already-capitalized names unchanged."""
    for name in ("Ethernet", "Vlan", "Loopback"):
        assert interface_case_sensitivity(name) == name


def test_interface_case_sensitivity_uppercase() -> None:
    """Check that interface_case_sensitivity leaves fully uppercase names unchanged."""
    for name in ("ETHERNET", "VLAN", "LOOPBACK"):
        assert interface_case_sensitivity(name) == name

0 comments on commit cf595fd

Please sign in to comment.