Skip to content

Commit be00292

Browse files
authored
Consistently use Python 3.10 type annotations (#263)
* Use future annotations for requirements module.
* Use future annotations for utils module.
* Use future annotations for base module.
* Use future annotations for column module.
* Use future annotations for groupby module.
* Use future annotations for interval module.
* Use future annotations for miscs module.
* Use future annotations for nrows module.
* Use future annotations for row module.
* Use future annotations for stats module.
* Use future annotations for uniques module.
* Use future annotations for varchar module.
1 parent 1429f9d commit be00292

12 files changed

+507
-484
lines changed

src/datajudge/constraints/base.py

+25-23
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import abc
24
from dataclasses import dataclass, field
35
from functools import lru_cache
4-
from typing import Any, Callable, Collection, List, Optional, Tuple, TypeVar, Union
6+
from typing import Any, Callable, Collection, List, Optional, TypeVar
57

68
import sqlalchemy as sa
79

@@ -16,7 +18,7 @@
1618
ToleranceGetter = Callable[[sa.engine.Engine], float]
1719

1820

19-
def uncommon_substrings(string1: str, string2: str) -> Tuple[str, str]:
21+
def uncommon_substrings(string1: str, string2: str) -> tuple[str, str]:
2022
qualifiers1 = string1.split(".")
2123
qualifiers2 = string2.split(".")
2224
if qualifiers1[0] != qualifiers2[0]:
@@ -29,29 +31,29 @@ def uncommon_substrings(string1: str, string2: str) -> Tuple[str, str]:
2931
@dataclass(frozen=True)
3032
class TestResult:
3133
outcome: bool
32-
_failure_message: Optional[str] = field(default=None, repr=False)
33-
_constraint_description: Optional[str] = field(default=None, repr=False)
34-
_factual_queries: Optional[str] = field(default=None, repr=False)
35-
_target_queries: Optional[str] = field(default=None, repr=False)
34+
_failure_message: str | None = field(default=None, repr=False)
35+
_constraint_description: str | None = field(default=None, repr=False)
36+
_factual_queries: str | None = field(default=None, repr=False)
37+
_target_queries: str | None = field(default=None, repr=False)
3638

37-
def formatted_failure_message(self, formatter: Formatter) -> Optional[str]:
39+
def formatted_failure_message(self, formatter: Formatter) -> str | None:
3840
return (
3941
formatter.fmt_str(self._failure_message) if self._failure_message else None
4042
)
4143

42-
def formatted_constraint_description(self, formatter: Formatter) -> Optional[str]:
44+
def formatted_constraint_description(self, formatter: Formatter) -> str | None:
4345
return (
4446
formatter.fmt_str(self._constraint_description)
4547
if self._constraint_description
4648
else None
4749
)
4850

4951
@property
50-
def failure_message(self) -> Optional[str]:
52+
def failure_message(self) -> str | None:
5153
return self.formatted_failure_message(DEFAULT_FORMATTER)
5254

5355
@property
54-
def constraint_description(self) -> Optional[str]:
56+
def constraint_description(self) -> str | None:
5557
return self.formatted_constraint_description(DEFAULT_FORMATTER)
5658

5759
@property
@@ -121,12 +123,12 @@ def __init__(
121123
self,
122124
ref: DataReference,
123125
*,
124-
ref2: Optional[DataReference] = None,
125-
ref_value: Optional[Any] = None,
126-
name: Optional[str] = None,
127-
output_processors: Optional[
128-
Union[OutputProcessor, List[OutputProcessor]]
129-
] = output_processor_limit,
126+
ref2: DataReference | None = None,
127+
ref_value: Any = None,
128+
name: str | None = None,
129+
output_processors: OutputProcessor
130+
| list[OutputProcessor]
131+
| None = output_processor_limit,
130132
cache_size=None,
131133
):
132134
self._check_if_valid_between_or_within(ref2, ref_value)
@@ -136,8 +138,8 @@ def __init__(
136138
self.name = name
137139
self.factual_selections: OptionalSelections = None
138140
self.target_selections: OptionalSelections = None
139-
self.factual_queries: Optional[List[str]] = None
140-
self.target_queries: Optional[List[str]] = None
141+
self.factual_queries: list[str] | None = None
142+
self.target_queries: list[str] | None = None
141143

142144
if (output_processors is not None) and (
143145
not isinstance(output_processors, list)
@@ -156,7 +158,9 @@ def _setup_caching(self):
156158
self.get_target_value = lru_cache(self.cache_size)(self.get_target_value) # type: ignore[method-assign]
157159

158160
def _check_if_valid_between_or_within(
159-
self, ref2: Optional[DataReference], ref_value: Optional[Any]
161+
self,
162+
ref2: DataReference | None,
163+
ref_value: Any,
160164
):
161165
"""Check whether exactly one of ref2 and ref_value arguments have been used."""
162166
class_name = self.__class__.__name__
@@ -228,13 +232,11 @@ def condition_string(self) -> str:
228232

229233
def retrieve(
230234
self, engine: sa.engine.Engine, ref: DataReference
231-
) -> Tuple[Any, OptionalSelections]:
235+
) -> tuple[Any, OptionalSelections]:
232236
"""Retrieve the value of interest for a DataReference from database."""
233237
pass
234238

235-
def compare(
236-
self, value_factual: Any, value_target: Any
237-
) -> Tuple[bool, Optional[str]]:
239+
def compare(self, value_factual: Any, value_target: Any) -> tuple[bool, str | None]:
238240
pass
239241

240242
def test(self, engine: sa.engine.Engine) -> TestResult:

src/datajudge/constraints/column.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from __future__ import annotations
2+
13
import abc
2-
from typing import List, Optional, Tuple, Union
34

45
import sqlalchemy as sa
56

@@ -11,7 +12,7 @@
1112
class Column(Constraint, abc.ABC):
1213
def retrieve(
1314
self, engine: sa.engine.Engine, ref: DataReference
14-
) -> Tuple[List[str], OptionalSelections]:
15+
) -> tuple[list[str], OptionalSelections]:
1516
# TODO: This does not 'belong' here. Rather, `retrieve` should be free of
1617
# side effects. This should be removed as soon as snowflake column capitalization
1718
# is fixed by snowflake-sqlalchemy.
@@ -24,15 +25,15 @@ class ColumnExistence(Column):
2425
def __init__(
2526
self,
2627
ref: DataReference,
27-
columns: List[str],
28-
name: Optional[str] = None,
28+
columns: list[str],
29+
name: str | None = None,
2930
cache_size=None,
3031
):
3132
super().__init__(ref, ref_value=columns, name=name, cache_size=cache_size)
3233

3334
def compare(
34-
self, column_names_factual: List[str], column_names_target: List[str]
35-
) -> Tuple[bool, str]:
35+
self, column_names_factual: list[str], column_names_target: list[str]
36+
) -> tuple[bool, str]:
3637
excluded_columns = list(
3738
filter(lambda c: c not in column_names_factual, column_names_target)
3839
)
@@ -45,8 +46,8 @@ def compare(
4546

4647
class ColumnSubset(Column):
4748
def compare(
48-
self, column_names_factual: List[str], column_names_target: List[str]
49-
) -> Tuple[bool, str]:
49+
self, column_names_factual: list[str], column_names_target: list[str]
50+
) -> tuple[bool, str]:
5051
missing_columns = list(
5152
filter(lambda c: c not in column_names_target, column_names_factual)
5253
)
@@ -59,8 +60,8 @@ def compare(
5960

6061
class ColumnSuperset(Column):
6162
def compare(
62-
self, column_names_factual: List[str], column_names_target: List[str]
63-
) -> Tuple[bool, str]:
63+
self, column_names_factual: list[str], column_names_target: list[str]
64+
) -> tuple[bool, str]:
6465
missing_columns = list(
6566
filter(lambda c: c not in column_names_factual, column_names_target)
6667
)
@@ -87,9 +88,9 @@ def __init__(
8788
self,
8889
ref: DataReference,
8990
*,
90-
ref2: Optional[DataReference] = None,
91-
column_type: Optional[Union[str, sa.types.TypeEngine]] = None,
92-
name: Optional[str] = None,
91+
ref2: DataReference | None = None,
92+
column_type: str | sa.types.TypeEngine | None = None,
93+
name: str | None = None,
9394
cache_size=None,
9495
):
9596
super().__init__(
@@ -103,11 +104,11 @@ def __init__(
103104

104105
def retrieve(
105106
self, engine: sa.engine.Engine, ref: DataReference
106-
) -> Tuple[sa.types.TypeEngine, OptionalSelections]:
107+
) -> tuple[sa.types.TypeEngine, OptionalSelections]:
107108
result, selections = db_access.get_column_type(engine, ref)
108109
return result, selections
109110

110-
def compare(self, column_type_factual, column_type_target) -> Tuple[bool, str]:
111+
def compare(self, column_type_factual, column_type_target) -> tuple[bool, str]:
111112
assertion_message = (
112113
f"{self.ref} is {column_type_factual} instead of {column_type_target}."
113114
)

src/datajudge/constraints/groupby.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Optional, Tuple
1+
from __future__ import annotations
2+
3+
from typing import Any
24

35
import sqlalchemy as sa
46

@@ -13,11 +15,11 @@ def __init__(
1315
ref: DataReference,
1416
aggregation_column: str,
1517
start_value: int = 0,
16-
name: Optional[str] = None,
18+
name: str | None = None,
1719
cache_size=None,
1820
*,
1921
tolerance: float = 0,
20-
ref2: Optional[DataReference] = None,
22+
ref2: DataReference | None = None,
2123
):
2224
super().__init__(ref, ref2=ref2, ref_value=object(), name=name)
2325
self.aggregation_column = aggregation_column
@@ -27,14 +29,14 @@ def __init__(
2729

2830
def retrieve(
2931
self, engine: sa.engine.Engine, ref: DataReference
30-
) -> Tuple[Any, OptionalSelections]:
32+
) -> tuple[Any, OptionalSelections]:
3133
result, selections = db_access.get_column_array_agg(
3234
engine, ref, self.aggregation_column
3335
)
3436
result = {fact[:-1]: fact[-1] for fact in result}
3537
return result, selections
3638

37-
def compare(self, factual: Any, target: Any) -> Tuple[bool, Optional[str]]:
39+
def compare(self, factual: Any, target: Any) -> tuple[bool, str | None]:
3840
def missing_from_range(values, start=0):
3941
return set(range(start, max(values) + start)) - set(values)
4042

src/datajudge/constraints/interval.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
import abc
2-
from typing import Any, List, Optional, Tuple
4+
from typing import Any
35

46
import sqlalchemy as sa
57

@@ -14,11 +16,11 @@ class IntervalConstraint(Constraint):
1416
def __init__(
1517
self,
1618
ref: DataReference,
17-
key_columns: Optional[List[str]],
18-
start_columns: List[str],
19-
end_columns: List[str],
19+
key_columns: list[str] | None,
20+
start_columns: list[str],
21+
end_columns: list[str],
2022
max_relative_n_violations: float,
21-
name: Optional[str] = None,
23+
name: str | None = None,
2224
cache_size=None,
2325
):
2426
super().__init__(ref, ref_value=object(), name=name)
@@ -44,7 +46,7 @@ def _validate_dimensions(self):
4446

4547
def retrieve(
4648
self, engine: sa.engine.Engine, ref: DataReference
47-
) -> Tuple[Tuple[int, int], OptionalSelections]:
49+
) -> tuple[tuple[int, int], OptionalSelections]:
4850
keys_ref = DataReference(
4951
data_source=self.ref.data_source,
5052
columns=self.key_columns,
@@ -69,12 +71,12 @@ class NoOverlapConstraint(IntervalConstraint):
6971
def __init__(
7072
self,
7173
ref: DataReference,
72-
key_columns: Optional[List[str]],
73-
start_columns: List[str],
74-
end_columns: List[str],
74+
key_columns: list[str] | None,
75+
start_columns: list[str],
76+
end_columns: list[str],
7577
max_relative_n_violations: float,
7678
end_included: bool,
77-
name: Optional[str] = None,
79+
name: str | None = None,
7880
cache_size=None,
7981
):
8082
self.end_included = end_included
@@ -110,12 +112,12 @@ class NoGapConstraint(IntervalConstraint):
110112
def __init__(
111113
self,
112114
ref: DataReference,
113-
key_columns: Optional[List[str]],
114-
start_columns: List[str],
115-
end_columns: List[str],
115+
key_columns: list[str] | None,
116+
start_columns: list[str],
117+
end_columns: list[str],
116118
max_relative_n_violations: float,
117119
legitimate_gap_size: float,
118-
name: Optional[str] = None,
120+
name: str | None = None,
119121
cache_size=None,
120122
):
121123
self.legitimate_gap_size = legitimate_gap_size
@@ -134,5 +136,5 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
134136
pass
135137

136138
@abc.abstractmethod
137-
def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
139+
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
138140
pass

src/datajudge/constraints/miscs.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from __future__ import annotations
2+
13
import warnings
2-
from typing import List, Optional, Set, Tuple
34

45
import sqlalchemy as sa
56

@@ -12,24 +13,24 @@ class PrimaryKeyDefinition(Constraint):
1213
def __init__(
1314
self,
1415
ref,
15-
primary_keys: List[str],
16-
name: Optional[str] = None,
16+
primary_keys: list[str],
17+
name: str | None = None,
1718
cache_size=None,
1819
):
1920
super().__init__(ref, ref_value=set(primary_keys), name=name)
2021

2122
def retrieve(
2223
self, engine: sa.engine.Engine, ref: DataReference
23-
) -> Tuple[Set[str], OptionalSelections]:
24+
) -> tuple[set[str], OptionalSelections]:
2425
if db_access.is_impala(engine):
2526
raise NotImplementedError("Primary key retrieval does not work for Impala.")
2627
values, selections = db_access.get_primary_keys(engine, self.ref)
2728
return set(values), selections
2829

2930
# Note: Exact equality!
3031
def compare(
31-
self, primary_keys_factual: Set[str], primary_keys_target: Set[str]
32-
) -> Tuple[bool, Optional[str]]:
32+
self, primary_keys_factual: set[str], primary_keys_target: set[str]
33+
) -> tuple[bool, str | None]:
3334
assertion_message = ""
3435
result = True
3536
# If both are true, just report one.
@@ -61,7 +62,7 @@ def __init__(
6162
max_duplicate_fraction: float = 0,
6263
max_absolute_n_duplicates: int = 0,
6364
infer_pk_columns: bool = False,
64-
name: Optional[str] = None,
65+
name: str | None = None,
6566
cache_size=None,
6667
):
6768
if max_duplicate_fraction != 0 and max_absolute_n_duplicates != 0:
@@ -125,7 +126,7 @@ def test(self, engine: sa.engine.Engine) -> TestResult:
125126

126127

127128
class FunctionalDependency(Constraint):
128-
def __init__(self, ref: DataReference, key_columns: List[str], **kwargs):
129+
def __init__(self, ref: DataReference, key_columns: list[str], **kwargs):
129130
super().__init__(ref, ref_value=object(), **kwargs)
130131
self.key_columns = key_columns
131132

@@ -155,10 +156,10 @@ def __init__(
155156
self,
156157
ref,
157158
*,
158-
ref2: Optional[DataReference] = None,
159-
max_null_fraction: Optional[float] = None,
159+
ref2: DataReference | None = None,
160+
max_null_fraction: float | None = None,
160161
max_relative_deviation: float = 0,
161-
name: Optional[str] = None,
162+
name: str | None = None,
162163
cache_size=None,
163164
):
164165
super().__init__(
@@ -184,7 +185,7 @@ def retrieve(self, engine: sa.engine.Engine, ref: DataReference):
184185

185186
def compare(
186187
self, missing_fraction_factual: float, missing_fracion_target: float
187-
) -> Tuple[bool, Optional[str]]:
188+
) -> tuple[bool, str | None]:
188189
threshold = missing_fracion_target * (1 + self.max_relative_deviation)
189190
result = missing_fraction_factual <= threshold
190191
assertion_text = (

0 commit comments

Comments
 (0)