Skip to content

Commit

Permalink
✨ let func.py back
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 25, 2024
1 parent 6dd516e commit 08242af
Show file tree
Hide file tree
Showing 5 changed files with 464 additions and 25 deletions.
50 changes: 49 additions & 1 deletion nepattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
import re
import sys
from typing import Any, Callable, Final, ForwardRef, Generic, Match, TypeVar, Union, final
from typing import Any, Callable, Final, ForwardRef, Generic, Match, TypeVar, Union, final, overload

from tarina import DateParser, lang

Expand All @@ -17,6 +17,14 @@
TDefault = TypeVar("TDefault")
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_T4 = TypeVar("_T4")
_T5 = TypeVar("_T5")
_T6 = TypeVar("_T6")
_T7 = TypeVar("_T7")
_T8 = TypeVar("_T8")
_T9 = TypeVar("_T9")

_TP = TypeVar("_TP", bound=Pattern)

Expand Down Expand Up @@ -162,6 +170,46 @@ def of(cls, *types: type[_T1]) -> UnionPattern[_T1]:

return cls([parser(i) for i in types]) # type: ignore

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], /) -> UnionPattern[_T1 | _T2]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], /) -> UnionPattern[_T1 | _T2 | _T3]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], pat4: Pattern[_T4], /) -> UnionPattern[_T1 | _T2 | _T3 | _T4]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], pat4: Pattern[_T4], pat5: Pattern[_T5], /) -> UnionPattern[_T1 | _T2 | _T3 | _T4 | _T5]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], pat4: Pattern[_T4], pat5: Pattern[_T5], pat6: Pattern[_T6], /) -> UnionPattern[_T1 | _T2 | _T3 | _T4 | _T5 | _T6]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], pat4: Pattern[_T4], pat5: Pattern[_T5], pat6: Pattern[_T6], pat7: Pattern[_T7], /) -> UnionPattern[_T1 | _T2 | _T3 | _T4 | _T5 | _T6 | _T7]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], pat4: Pattern[_T4], pat5: Pattern[_T5], pat6: Pattern[_T6], pat7: Pattern[_T7], pat8: Pattern[_T8], /) -> UnionPattern[_T1 | _T2 | _T3 | _T4 | _T5 | _T6 | _T7 | _T8]: ...

@classmethod
@overload
def with_(cls, pat1: Pattern[_T1], pat2: Pattern[_T2], pat3: Pattern[_T3], pat4: Pattern[_T4], pat5: Pattern[_T5], pat6: Pattern[_T6], pat7: Pattern[_T7], pat8: Pattern[_T8], pat9: Pattern[_T9], /) -> UnionPattern[_T1 | _T2 | _T3 | _T4 | _T5 | _T6 | _T7 | _T8 | _T9]: ...

@classmethod
@overload
def with_(cls, *patterns: Pattern[_T]) -> UnionPattern[_T]: ...

@classmethod
def with_(cls, *patterns: Pattern) -> UnionPattern:
return cls(*patterns)

def __repr__(self):
return "|".join(repr(a) for a in (*self.for_validate, *self.for_equal))

Expand Down
50 changes: 35 additions & 15 deletions nepattern/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .util import TPattern

T = TypeVar("T")
_T = TypeVar("_T")


class ValidateResult(Generic[T]):
Expand All @@ -28,21 +29,25 @@ def __init__(
__slots__ = ("_value", "_error")

def value(self) -> T:
"""获取验证结果"""
if self._value is Empty:
raise RuntimeError("cannot access value")
return self._value # type: ignore

def error(self) -> Exception | None:
"""获取验证错误"""
if self._error is not Empty:
assert isinstance(self._error, Exception)
return self._error

@property
def success(self) -> bool:
"""是否验证成功"""
return self._value is not Empty

@property
def failed(self) -> bool:
"""是否验证失败"""
return self._value is Empty

def __bool__(self): # pragma: no cover
Expand All @@ -58,11 +63,12 @@ def __repr__(self):

class Pattern(Generic[T]):
@staticmethod
def regex_match(pattern: str | TPattern, alias: str | None = None):
def regex_match(pattern: str | TPattern, alias: str | None = None) -> _RegexPattern[str]:
"""构建一个仅正则表达式匹配的 Pattern,不进行转换"""
pat = _RegexPattern(pattern, str, alias or str(pattern))

@pat.convert
def _(self, x: str):
def _(self: _RegexPattern, x: str):
mat = re.match(self.pattern, x) or re.search(self.pattern, x)
if not mat:
raise MatchFailed(
Expand All @@ -75,17 +81,18 @@ def _(self, x: str):
@staticmethod
def regex_convert(
pattern: str,
origin: type[T],
fn: Callable[[re.Match], T],
origin: type[_T],
fn: Callable[[re.Match[str]], _T],
alias: str | None = None,
allow_origin: bool = False,
):
) -> _RegexPattern[_T]:
"""构建一个正则表达式匹配的 Pattern,并提供转换函数"""
pat = _RegexPattern(pattern, origin, alias or str(pattern))
if allow_origin:
pat.accept(Union[str, origin])

@pat.convert
def _(self, x):
def _(self: _RegexPattern, x):
if isinstance(x, origin):
return x
mat = re.match(self.pattern, x) or re.search(self.pattern, x)
Expand All @@ -99,7 +106,7 @@ def _(self, x):
pat.accept(str)

@pat.convert
def _(self, x: str):
def _(self: _RegexPattern, x: str):
mat = re.match(self.pattern, x) or re.search(self.pattern, x)
if not mat:
raise MatchFailed(
Expand All @@ -110,7 +117,7 @@ def _(self, x: str):
return pat

@staticmethod
def on(obj: T):
def on(obj: _T) -> Pattern[_T]:
"""提供 DataUnit 类型的构造方法"""
from .base import DirectPattern

Expand All @@ -127,26 +134,36 @@ def __init__(self, origin: type[T] | None = None, alias: str | None = None):
self.alias = alias

self._accepts = Any
self._post_validator = lambda x: generic_isinstance(x, self.origin)
self._pre_validator = None
self._post_validator = None
self._pre_validator = (lambda x: generic_isinstance(x, self.origin)) if origin else None
self._converter = None
self._pre_validate_modified = False

def __init_subclass__(cls, **kwargs):
cls.__hash__ = Pattern.__hash__

def accept(self, input_type: Any):
"""设置接受的输入类型"""
if input_type is ...:
input_type = Any
self._accepts = input_type
if not self._pre_validate_modified:
self._pre_validator = None
return self

def pre_validate(self, func: Callable[[Any], bool]):
"""设置预验证函数 (经过 accept 后,convert 前)"""
self._pre_validator = func
self._pre_validate_modified = True
return self

def post_validate(self, func: Callable[[T], bool]):
"""设置后验证函数 (convert 后,仅当设置了 converter 才会生效)"""
self._post_validator = func
return self

def convert(self, func: Callable[[Self, Any], T | None]):
"""设置转换函数, 返回 None 时表示转换失败"""
self._converter = func
return self

Expand All @@ -165,13 +182,14 @@ def match(self, input_: Any) -> T:
raise MatchFailed(
lang.require("nepattern", "error.content").format(target=input_, expected=self.origin)
)
if self._post_validator and not self._post_validator(input_):
raise MatchFailed(
lang.require("nepattern", "error.content").format(target=input_, expected=self.origin)
)
if self._post_validator and not self._post_validator(input_):
raise MatchFailed(
lang.require("nepattern", "error.content").format(target=input_, expected=self.origin)
)
return input_

def execute(self, input_: Any) -> ValidateResult[T]:
"""执行验证"""
try:
return ValidateResult(self.match(input_))
except Exception as e:
Expand All @@ -185,7 +203,7 @@ def __str__(self):
return f"{getattr(self._accepts, '__name__', self._accepts)} -> {getattr(self.origin, '__name__', self.origin)}"

def __repr__(self):
return f"{self.__class__.__name__}({self.origin}, {self.alias})"
return f"{self.__class__.__name__}({self.origin}, {self.alias!r})"

def copy(self) -> Self:
return deepcopy(self)
Expand Down Expand Up @@ -222,6 +240,7 @@ def __init__(self, pattern: str | TPattern, origin: type[T], alias: str | None =
self.pattern = re.compile(f"^{pattern.pattern}$", pattern.flags)

def prefixed(self):
"""转为前缀型匹配"""
new = self.copy()
if isinstance(self.pattern, str):
new.pattern = self.pattern[:-1]
Expand All @@ -230,6 +249,7 @@ def prefixed(self):
return new

def suffixed(self):
"""转为后缀型匹配"""
new = self.copy()
if isinstance(self.pattern, str):
new.pattern = self.pattern[1:]
Expand Down
Loading

0 comments on commit 08242af

Please sign in to comment.