1
+ from __future__ import annotations
2
+
1
3
import abc
2
4
from dataclasses import dataclass , field
3
5
from functools import lru_cache
4
- from typing import Any , Callable , Collection , List , Optional , Tuple , TypeVar , Union
6
+ from typing import Any , Callable , Collection , List , Optional , TypeVar
5
7
6
8
import sqlalchemy as sa
7
9
16
18
ToleranceGetter = Callable [[sa .engine .Engine ], float ]
17
19
18
20
19
- def uncommon_substrings (string1 : str , string2 : str ) -> Tuple [str , str ]:
21
+ def uncommon_substrings (string1 : str , string2 : str ) -> tuple [str , str ]:
20
22
qualifiers1 = string1 .split ("." )
21
23
qualifiers2 = string2 .split ("." )
22
24
if qualifiers1 [0 ] != qualifiers2 [0 ]:
@@ -29,29 +31,29 @@ def uncommon_substrings(string1: str, string2: str) -> Tuple[str, str]:
29
31
@dataclass (frozen = True )
30
32
class TestResult :
31
33
outcome : bool
32
- _failure_message : Optional [ str ] = field (default = None , repr = False )
33
- _constraint_description : Optional [ str ] = field (default = None , repr = False )
34
- _factual_queries : Optional [ str ] = field (default = None , repr = False )
35
- _target_queries : Optional [ str ] = field (default = None , repr = False )
34
+ _failure_message : str | None = field (default = None , repr = False )
35
+ _constraint_description : str | None = field (default = None , repr = False )
36
+ _factual_queries : str | None = field (default = None , repr = False )
37
+ _target_queries : str | None = field (default = None , repr = False )
36
38
37
- def formatted_failure_message (self , formatter : Formatter ) -> Optional [ str ] :
39
+ def formatted_failure_message (self , formatter : Formatter ) -> str | None :
38
40
return (
39
41
formatter .fmt_str (self ._failure_message ) if self ._failure_message else None
40
42
)
41
43
42
- def formatted_constraint_description (self , formatter : Formatter ) -> Optional [ str ] :
44
+ def formatted_constraint_description (self , formatter : Formatter ) -> str | None :
43
45
return (
44
46
formatter .fmt_str (self ._constraint_description )
45
47
if self ._constraint_description
46
48
else None
47
49
)
48
50
49
51
@property
50
- def failure_message (self ) -> Optional [ str ] :
52
+ def failure_message (self ) -> str | None :
51
53
return self .formatted_failure_message (DEFAULT_FORMATTER )
52
54
53
55
@property
54
- def constraint_description (self ) -> Optional [ str ] :
56
+ def constraint_description (self ) -> str | None :
55
57
return self .formatted_constraint_description (DEFAULT_FORMATTER )
56
58
57
59
@property
@@ -121,12 +123,12 @@ def __init__(
121
123
self ,
122
124
ref : DataReference ,
123
125
* ,
124
- ref2 : Optional [ DataReference ] = None ,
125
- ref_value : Optional [ Any ] = None ,
126
- name : Optional [ str ] = None ,
127
- output_processors : Optional [
128
- Union [OutputProcessor , List [ OutputProcessor ] ]
129
- ] = output_processor_limit ,
126
+ ref2 : DataReference | None = None ,
127
+ ref_value : Any = None ,
128
+ name : str | None = None ,
129
+ output_processors : OutputProcessor
130
+ | list [OutputProcessor ]
131
+ | None = output_processor_limit ,
130
132
cache_size = None ,
131
133
):
132
134
self ._check_if_valid_between_or_within (ref2 , ref_value )
@@ -136,8 +138,8 @@ def __init__(
136
138
self .name = name
137
139
self .factual_selections : OptionalSelections = None
138
140
self .target_selections : OptionalSelections = None
139
- self .factual_queries : Optional [ List [ str ]] = None
140
- self .target_queries : Optional [ List [ str ]] = None
141
+ self .factual_queries : list [ str ] | None = None
142
+ self .target_queries : list [ str ] | None = None
141
143
142
144
if (output_processors is not None ) and (
143
145
not isinstance (output_processors , list )
@@ -156,7 +158,9 @@ def _setup_caching(self):
156
158
self .get_target_value = lru_cache (self .cache_size )(self .get_target_value ) # type: ignore[method-assign]
157
159
158
160
def _check_if_valid_between_or_within (
159
- self , ref2 : Optional [DataReference ], ref_value : Optional [Any ]
161
+ self ,
162
+ ref2 : DataReference | None ,
163
+ ref_value : Any ,
160
164
):
161
165
"""Check whether exactly one of ref2 and ref_value arguments have been used."""
162
166
class_name = self .__class__ .__name__
@@ -228,13 +232,11 @@ def condition_string(self) -> str:
228
232
229
233
def retrieve (
230
234
self , engine : sa .engine .Engine , ref : DataReference
231
- ) -> Tuple [Any , OptionalSelections ]:
235
+ ) -> tuple [Any , OptionalSelections ]:
232
236
"""Retrieve the value of interest for a DataReference from database."""
233
237
pass
234
238
235
- def compare (
236
- self , value_factual : Any , value_target : Any
237
- ) -> Tuple [bool , Optional [str ]]:
239
+ def compare (self , value_factual : Any , value_target : Any ) -> tuple [bool , str | None ]:
238
240
pass
239
241
240
242
def test (self , engine : sa .engine .Engine ) -> TestResult :
0 commit comments