Skip to content

Commit

Permalink
fix typing on search/match/findall/sub/... (#5)
Browse files Browse the repository at this point in the history

Co-authored-by: Donald Nguyen <ddn0@users.noreply.github.com>
  • Loading branch information
trim21 and ddn0 authored Oct 17, 2024
1 parent bf49b50 commit 8047d12
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 308 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ jobs:
- run: ruff check --output-format=github .
- run: mypy
- uses: jakebailey/pyright-action@v2
- run: pytest
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.venv/
.idea/
338 changes: 205 additions & 133 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ python = "^3.8.0"
pyright = "^1.1.365"
mypy = "^1.10.0"
ruff = "^0.4.7"
pytest = "^8.3.3"

[tool.ruff]
force-exclude = true
Expand Down Expand Up @@ -152,3 +153,9 @@ reportUnnecessaryTypeIgnoreComment = "error"
[tool.mypy]
strict = true
packages = "re2-stubs"

[tool.pytest.ini_options]
xfail_strict = true
testpaths = [
"tests"
]
248 changes: 74 additions & 174 deletions src/re2-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from typing import AnyStr, Generic, Iterator, Literal, TypeVar, overload
from typing import (
AnyStr,
Callable,
Generic,
Iterator,
Literal,
TypeVar,
overload,
)

from typing_extensions import TypeAlias

# References:
# - https://github.com/google/re2/blob/main/re2/re2.h and
Expand All @@ -7,6 +17,8 @@ from typing import AnyStr, Generic, Iterator, Literal, TypeVar, overload

_T = TypeVar("_T")

_Pattern: TypeAlias = _Regexp[str] | str | _Regexp[bytes] | bytes

class error(Exception): ...

class Options:
Expand All @@ -26,170 +38,80 @@ class Options:
def compile(
pattern: _Regexp[AnyStr] | AnyStr, options: Options | None = None
) -> _Regexp[AnyStr]: ...
@overload
def search(
pattern: _Regexp[str] | str, text: str, options: Options | None = None
) -> _Match[str] | None: ...
@overload
def search(
pattern: _Regexp[bytes] | bytes, text: bytes, options: Options | None = None
) -> _Match[bytes] | None: ...
@overload
def match(
pattern: _Regexp[str] | str, text: str, options: Options | None = None
) -> _Match[str] | None: ...
@overload
pattern: _Pattern, text: AnyStr, options: Options | None = None
) -> _Match[AnyStr] | None: ...
def match(
pattern: _Regexp[bytes] | bytes, text: bytes, options: Options | None = None
) -> _Match[bytes] | None: ...
@overload
pattern: _Pattern, text: AnyStr, options: Options | None = None
) -> _Match[AnyStr] | None: ...
def fullmatch(
pattern: _Regexp[str] | str, text: str, options: Options | None = None
) -> _Match[str] | None: ...
@overload
def fullmatch(
pattern: _Regexp[bytes] | bytes, text: bytes, options: Options | None = None
) -> _Match[bytes] | None: ...
@overload
def finditer(
pattern: _Regexp[str] | str, text: str, options: Options | None = None
) -> Iterator[_Match[str]]: ...
@overload
pattern: _Pattern, text: AnyStr, options: Options | None = None
) -> _Match[AnyStr] | None: ...
def finditer(
pattern: _Regexp[bytes] | bytes, text: bytes, options: Options | None = None
) -> Iterator[_Match[bytes]]: ...
@overload
def findall(
pattern: _Regexp[str] | str, text: str, options: Options | None = None
) -> list[str]: ...
@overload
pattern: _Pattern, text: AnyStr, options: Options | None = None
) -> Iterator[_Match[AnyStr]]: ...
def findall(
pattern: _Regexp[bytes] | bytes, text: bytes, options: Options | None = None
) -> list[bytes]: ...
@overload
pattern: _Pattern, text: AnyStr, options: Options | None = None
) -> list[AnyStr]: ...
def split(
pattern: _Regexp[str] | str,
text: str,
pattern: _Pattern,
text: AnyStr,
maxsplit: int = 0,
options: Options | None = None,
) -> list[str]: ...
@overload
def split(
pattern: _Regexp[bytes] | bytes,
text: bytes,
maxsplit: int = 0,
options: Options | None = None,
) -> list[bytes]: ...
@overload
) -> list[AnyStr]: ...
def subn(
pattern: _Regexp[str] | str,
repl: str,
text: str,
pattern: _Pattern,
repl: AnyStr | Callable[[_Match[AnyStr]], AnyStr],
text: AnyStr,
count: int = 0,
options: Options | None = None,
) -> tuple[str, int]: ...
@overload
def subn(
pattern: _Regexp[bytes] | bytes,
repl: bytes,
text: bytes,
count: int = 0,
options: Options | None = None,
) -> tuple[bytes, int]: ...
@overload
) -> tuple[AnyStr, int]: ...
def sub(
pattern: _Regexp[str] | str,
repl: str,
text: str,
pattern: _Pattern,
repl: AnyStr | Callable[[_Match[AnyStr]], AnyStr],
text: AnyStr,
count: int = 0,
options: Options | None = None,
) -> str: ...
@overload
def sub(
pattern: _Regexp[bytes] | bytes,
repl: bytes,
text: bytes,
count: int = 0,
options: Options | None = None,
) -> bytes: ...
) -> AnyStr: ...
def escape(pattern: AnyStr) -> AnyStr: ...
def purge() -> None: ...

# re2 regexps produce match objects based on the text to match regardless
# of the initial pattern the regexp was constructed with. Introduce a separately
# constrained AnyStr which is uncorrelated with the string type the regexp
# was originally constructed with to represent this re2 feature.
_AnyStr2 = TypeVar("_AnyStr2", str, bytes)

class _Regexp(Generic[AnyStr]):
def __init__(self, pattern: AnyStr, options: Options) -> None: ...
@overload
def search(
self: _Regexp[str], text: str, pos: int | None = None, endpos: int | None = None
) -> _Match[str] | None: ...
@overload
def search(
self: _Regexp[bytes],
text: bytes,
pos: int | None = None,
endpos: int | None = None,
) -> _Match[bytes] | None: ...
@overload
self, text: _AnyStr2, pos: int | None = None, endpos: int | None = None
) -> _Match[_AnyStr2] | None: ...
def match(
self: _Regexp[str], text: str, pos: int | None = None, endpos: int | None = None
) -> _Match[str] | None: ...
@overload
def match(
self: _Regexp[bytes],
text: bytes,
pos: int | None = None,
endpos: int | None = None,
) -> _Match[bytes] | None: ...
@overload
self, text: _AnyStr2, pos: int | None = None, endpos: int | None = None
) -> _Match[_AnyStr2] | None: ...
def fullmatch(
self: _Regexp[str], text: str, pos: int | None = None, endpos: int | None = None
) -> _Match[str] | None: ...
@overload
def fullmatch(
self: _Regexp[bytes],
text: bytes,
pos: int | None = None,
endpos: int | None = None,
) -> _Match[bytes] | None: ...
@overload
self, text: _AnyStr2, pos: int | None = None, endpos: int | None = None
) -> _Match[_AnyStr2] | None: ...
def finditer(
self: _Regexp[str], text: str, pos: int | None = None, endpos: int | None = None
) -> Iterator[_Match[str]]: ...
@overload
def finditer(
self: _Regexp[bytes],
text: bytes,
pos: int | None = None,
endpos: int | None = None,
) -> Iterator[_Match[bytes]]: ...
@overload
self, text: _AnyStr2, pos: int | None = None, endpos: int | None = None
) -> Iterator[_Match[_AnyStr2]]: ...
def findall(
self: _Regexp[str], text: str, pos: int | None = None, endpos: int | None = None
) -> list[str]: ...
@overload
def findall(
self: _Regexp[bytes],
text: bytes,
pos: int | None = None,
endpos: int | None = None,
) -> list[bytes]: ...
@overload
def split(self: _Regexp[str], text: str, maxsplit: int = 0) -> list[str]: ...
@overload
def split(self: _Regexp[bytes], text: bytes, maxsplit: int = 0) -> list[bytes]: ...
@overload
def subn(
self: _Regexp[str], repl: str, text: str, count: int = 0
) -> tuple[str, int]: ...
@overload
self, text: _AnyStr2, pos: int | None = None, endpos: int | None = None
) -> list[_AnyStr2]: ...
def split(self, text: _AnyStr2, maxsplit: int = 0) -> list[_AnyStr2]: ...
def subn(
self: _Regexp[bytes], repl: bytes, text: bytes, count: int = 0
) -> tuple[bytes, int]: ...
@overload
def sub(self: _Regexp[str], repl: str, text: str, count: int = 0) -> str: ...
@overload
self,
repl: _AnyStr2 | Callable[[_Match[_AnyStr2]], _AnyStr2],
text: _AnyStr2,
count: int = 0,
) -> tuple[_AnyStr2, int]: ...
def sub(
self: _Regexp[bytes], repl: bytes, text: bytes, count: int = 0
) -> bytes: ...
self,
repl: _AnyStr2 | Callable[[_Match[_AnyStr2]], _AnyStr2],
text: _AnyStr2,
count: int = 0,
) -> _AnyStr2: ...
@property
def pattern(self) -> AnyStr: ...
@property
Expand Down Expand Up @@ -217,50 +139,28 @@ class _Match(Generic[AnyStr]):
endpos: int,
spans: dict[int, tuple[int, int]],
) -> None: ...
def expand(self: _Match[AnyStr], template: str) -> AnyStr: ...
def __getitem__(self, group: int | str) -> AnyStr | None: ...
@overload
def expand(self: _Match[str], template: str) -> str: ...
@overload
def expand(self: _Match[bytes], template: bytes) -> bytes: ...
@overload
def __getitem__(self: _Match[str], group: int | str) -> str | None: ...
@overload
def __getitem__(self: _Match[bytes], group: int | bytes) -> bytes | None: ...
@overload
def group(self: _Match[str], group: Literal[0] = 0, /) -> str: ...
@overload
def group(self: _Match[bytes], group: Literal[0] = 0, /) -> bytes: ...
@overload
def group(self: _Match[str], group: str | int, /) -> str: ...
def group(self, group: Literal[0] = 0, /) -> AnyStr: ...
@overload
def group(self: _Match[bytes], group: bytes | int, /) -> bytes: ...
def group(self, group: AnyStr | int, /) -> AnyStr: ...
@overload
def group(
self: _Match[str], group1: str | int, group2: str | int, /, *groups: str | int
) -> tuple[str, ...]: ...
@overload
def group(
self: _Match[bytes],
group1: bytes | int,
group2: bytes | int,
self: _Match[str],
group1: AnyStr | int,
group2: AnyStr | int,
/,
*groups: bytes | int,
) -> tuple[bytes, ...]: ...
@overload
def groups(self: _Match[str]) -> tuple[str, ...]: ...
@overload
def groups(self: _Match[bytes]) -> tuple[bytes, ...]: ...
@overload
def groups(self: _Match[str], default: _T) -> tuple[str | _T, ...]: ...
@overload
def groups(self: _Match[bytes], default: _T) -> tuple[bytes | _T, ...]: ...
*groups: AnyStr | int,
) -> tuple[AnyStr, ...]: ...
@overload
def groupdict(self: _Match[str]) -> dict[str, str]: ...
def groups(self) -> tuple[AnyStr, ...]: ...
@overload
def groupdict(self: _Match[bytes]) -> dict[bytes, bytes]: ...
def groups(self, default: _T) -> tuple[AnyStr | _T, ...]: ...
@overload
def groupdict(self: _Match[str], default: _T) -> dict[str, str | _T]: ...
def groupdict(self: _Match[AnyStr]) -> dict[AnyStr, AnyStr]: ...
@overload
def groupdict(self: _Match[bytes], default: _T) -> dict[bytes, bytes | _T]: ...
def groupdict(self, default: _T) -> dict[AnyStr, AnyStr | _T]: ...
def start(self, group: int = 0) -> int: ...
def end(self, group: int = 0) -> int: ...
def span(self, group: int = 0) -> tuple[int, int]: ...
Expand Down
2 changes: 1 addition & 1 deletion tests/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Test completeness of stub with examples of re2 use."""
"""Test types match interface."""
52 changes: 52 additions & 0 deletions tests/api/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Test completeness of stub with examples of re2 use."""

# pyright: reportPrivateUsage = false
from __future__ import annotations

from typing_extensions import assert_type

import re2


def test_typing_only_tests() -> None:
_1: re2._Match[str] | None = re2.search(b"", "")
_2: re2._Match[bytes] | None = re2.search(b"", b"")
_3: re2._Match[bytes] | None = re2.search("", b"")
_4: re2._Match[str] | None = re2.search("", "")

_5: re2._Match[str] | None = re2.compile(b"").search("")
_6: re2._Match[bytes] | None = re2.compile(b"").search(b"")
_7: re2._Match[str] | None = re2.compile("").search("")
_8: re2._Match[bytes] | None = re2.compile("").search(b"")

assert assert_type(re2.sub(b"", "", ""), str) == ""
assert assert_type(re2.sub("", b"", b""), bytes) == b""

def repl_str(m: re2._Match[str]) -> str:
return ""

def repl_bytes(m: re2._Match[bytes]) -> bytes:
return b""

assert assert_type(re2.compile(b"").sub("", ""), str) == ""
assert assert_type(re2.compile("").sub("", ""), str) == ""
assert assert_type(re2.compile(b"").sub(repl_str, ""), str) == ""
assert assert_type(re2.compile("").sub(b"", b""), bytes) == b""
assert assert_type(re2.compile("").sub(repl_bytes, b""), bytes) == b""


def test_runtime_types() -> None:
assert re2.sub(b"", "", "") == ""
assert re2.sub("", b"", b"") == b""

def repl_str(m: re2._Match[str]) -> str:
return ""

def repl_bytes(m: re2._Match[bytes]) -> bytes:
return b""

assert re2.compile(b"").sub("", "") == ""
assert re2.compile("").sub("", "") == ""
assert re2.compile(b"").sub(repl_str, "") == ""
assert re2.compile("").sub(b"", b"") == b""
assert re2.compile("").sub(repl_bytes, b"") == b""

0 comments on commit 8047d12

Please sign in to comment.