diff --git a/approvaltests/namer/stack_frame_namer.py b/approvaltests/namer/stack_frame_namer.py index 2eea7618..6359550d 100644 --- a/approvaltests/namer/stack_frame_namer.py +++ b/approvaltests/namer/stack_frame_namer.py @@ -2,10 +2,12 @@ import os from inspect import FrameInfo from typing import Optional, Dict, List +import fnmatch from approvaltests.namer.namer_base import NamerBase from approvaltests.approval_exception import FrameNotFound from approval_utilities.utilities.stack_frame_utilities import get_class_name_for_frame +from approvaltests.pytest.pytest_config import PytestConfig class StackFrameNamer(NamerBase): @@ -40,7 +42,26 @@ def get_test_frame_index(caller: List[FrameInfo]) -> int: raise FrameNotFound(message) @staticmethod - def is_test_method(frame: FrameInfo) -> bool: + def is_pytest_test(frame: FrameInfo) -> bool: + method_name = frame[3] + patterns = PytestConfig.test_naming_patterns + + # taken from pytest/python.py (class PyCollector) + for pattern in patterns: + if method_name.startswith(pattern): + return True + # Check that name looks like a glob-string before calling fnmatch + # because this is called for every name in each collected module, + # and fnmatch is somewhat expensive to call. + elif ( + "*" in pattern or "?" in pattern or "[" in pattern + ) and fnmatch.fnmatch(method_name, pattern): + return True + + return False + + @staticmethod + def is_unittest_test(frame: FrameInfo) -> bool: method_name = frame[3] local_attributes = frame[0].f_locals is_unittest_test = ( @@ -51,10 +72,13 @@ def is_test_method(frame: FrameInfo) -> bool: and method_name != "_callTestMethod" and method_name != "run" ) + return is_unittest_test - is_pytest_test = method_name.startswith("test_") - - return is_unittest_test or is_pytest_test + @staticmethod + def is_test_method(frame: FrameInfo) -> bool: + return StackFrameNamer.is_unittest_test( + frame + ) or StackFrameNamer.is_pytest_test(frame) def get_class_name(self) -> str: return self.class_name diff --git a/approvaltests/pytest/pytest_config.py b/approvaltests/pytest/pytest_config.py new file mode 100644 index 00000000..a2928d26 --- /dev/null +++ b/approvaltests/pytest/pytest_config.py @@ -0,0 +1,10 @@ +class PytestConfig: + test_naming_patterns = ["test_*"] + + @staticmethod + def set_config(config): + PytestConfig.test_naming_patterns = config.getini("python_functions") + + +def set_pytest_config(config): + PytestConfig.set_config(config) diff --git a/approvaltests/pytest/pytest_plugin.py b/approvaltests/pytest/pytest_plugin.py new file mode 100644 index 00000000..c6be31b3 --- /dev/null +++ b/approvaltests/pytest/pytest_plugin.py @@ -0,0 +1,4 @@ +from .pytest_config import set_pytest_config + +def pytest_configure(config): + set_pytest_config(config) diff --git a/setup_utils.py b/setup_utils.py index a585cc48..ed748670 100644 --- a/setup_utils.py +++ b/setup_utils.py @@ -31,6 +31,11 @@ def do_the_setup(package_name, package_description, required, extra_requires): python_requires=">=3.8", packages=find_packages(include=["approvaltests*"]), package_data={"approvaltests": ["reporters/reporters.json"]}, + entry_points={ + 'pytest11': [ + 'approvaltests_pytest = approvaltests.pytest.pytest_plugin', + ], + }, install_requires=required, extras_require=extra_requires, long_description=(get_parent_directory() / "README.md").read_text(),