Skip to content

Commit 1e23940

Browse files
kkleinpavelzw
andauthored
Type-check package (#229)
* Draft usage of mypy. * Apply pchs. * Draft usage of pixi. * Adapt more integration tests. * Adapt bigquery integration test task. * Migrate docs deployment to pixi. * Update build and publish. * Run postinstall for readthedocs. * Adapt liniting to rely on pixi. * Add pytest as dependency. * Adapt readthedocs configuration. * Update development instructions. * Switch to pixi in helper script. * Minor adadptations. * Introduce mypy CI step. * Add changelog entry. * Fix usage of condarc file. * Adapt development.rst. * Use pixi for mypy job. * Fix yml syntax. * Create pixi environment for mypy. * Remove impala step for debugging. * Add type packages. * Comment out impala run for debugging. * Move flit to host-dependencies. * Remove redundant dependency. * Update .github/workflows/ci.yaml Co-authored-by: Pavel Zwerschke <pavel.zwerschke@quantco.com> * Add color to ci tests. * Update .github/workflows/ci.yaml Co-authored-by: Pavel Zwerschke <pavel.zwerschke@quantco.com> * Update .github/workflows/ci.yaml Co-authored-by: Pavel Zwerschke <pavel.zwerschke@quantco.com> * Remove configurations which were erroneously added. * Remove configurations which were erroneously added. * Consistently use ubuntu-latest. * Update .github/workflows/ci.yaml Co-authored-by: Pavel Zwerschke <pavel.zwerschke@quantco.com> * Also run unit tests on macos and windows. * Fix impala test syntax. * Downgrade sqlalchemy version for impala tests. * Downgrade sqlalchemy version for impala tests. * Disable BigQuery tests for now. --------- Co-authored-by: Pavel Zwerschke <pavel.zwerschke@quantco.com>
1 parent 73c1792 commit 1e23940

24 files changed

+963
-497
lines changed

.github/workflows/ci.yaml

+15
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ jobs:
2929
- name: pre-commit
3030
run: pixi run pre-commit-run --color=always --show-diff-on-failure
3131

32+
mypy-type-checks:
33+
name: Mypy Type Checks
34+
runs-on: ubuntu-latest
35+
steps:
36+
- name: Checkout branch
37+
uses: actions/checkout@v4
38+
- name: Set up pixi
39+
uses: prefix-dev/setup-pixi@v0.8.1
40+
with:
41+
environments: default lint
42+
- name: mypy
43+
run: |
44+
pixi run -e mypy postinstall
45+
pixi run -e mypy mypy .
46+
3247
unit-tests:
3348
name: "unit tests"
3449
strategy:

CHANGELOG.rst

+6
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,14 @@ Changelog
1111
------------------
1212

1313
**New features**
14+
1415
- Added styling for assertion messages. See :ref:`assertion-message-styling` for more information.
1516

17+
**Other changes**
18+
19+
- Provide a ``py.typed`` file.
20+
21+
1622
1.8.0 - 2023.06.16
1723
------------------
1824

pixi.lock

+621-236
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

+9
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ sqlalchemy = "2.*"
8888
pytest-cov = "*"
8989
pytest-xdist = "*"
9090

91+
[feature.mypy.dependencies]
92+
mypy = "*"
93+
types-setuptools = "*"
94+
types-colorama = "*"
95+
pandas-stubs = "*"
96+
types-jinja2 = "*"
97+
9198
[feature.lint.dependencies]
9299
pre-commit = "*"
93100
black = "*"
@@ -150,3 +157,5 @@ impala-py38 = ["impala", "py38", "sa1", "test"]
150157
impala-sa1 = ["impala", "sa1", "test"]
151158

152159
lint = { features = ["lint"], no-default-feature = true }
160+
161+
mypy = ["mypy"]

pyproject.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ skip_glob = '\.eggs/*,\.git/*,\.venv/*,build/*,dist/*'
4646
default_section = 'THIRDPARTY'
4747

4848
[tool.mypy]
49-
# Temporary fix.
50-
no_implicit_optional = false
49+
no_implicit_optional = true
5150
allow_empty_bodies = true
51+
check_untyped_defs = true
52+
53+
[[tool.mypy.overrides]]
54+
module = ["scipy.*", "impala.*", "pytest_html"]
55+
ignore_missing_imports = true

src/datajudge/constraints/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ def __init__(
117117
self,
118118
ref: DataReference,
119119
*,
120-
ref2=None,
121-
ref_value: Any = None,
122-
name: str = None,
120+
ref2: Optional[DataReference] = None,
121+
ref_value: Optional[Any] = None,
122+
name: Optional[str] = None,
123123
output_processors: Optional[
124124
Union[OutputProcessor, List[OutputProcessor]]
125125
] = output_processor_limit,

src/datajudge/constraints/column.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def retrieve(
2121

2222

2323
class ColumnExistence(Column):
24-
def __init__(self, ref: DataReference, columns: List[str], name: str = None):
24+
def __init__(
25+
self, ref: DataReference, columns: List[str], name: Optional[str] = None
26+
):
2527
super().__init__(ref, ref_value=columns, name=name)
2628

2729
def compare(

src/datajudge/constraints/date.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def __init__(
3838
ref: DataReference,
3939
use_lower_bound_reference: bool,
4040
column_type: str,
41-
name: str = None,
41+
name: Optional[str] = None,
4242
*,
43-
ref2: DataReference = None,
44-
min_value: str = None,
43+
ref2: Optional[DataReference] = None,
44+
min_value: Optional[str] = None,
4545
):
4646
self.format = get_format_from_column_type(column_type)
4747
self.use_lower_bound_reference = use_lower_bound_reference
@@ -84,10 +84,10 @@ def __init__(
8484
ref: DataReference,
8585
use_upper_bound_reference: bool,
8686
column_type: str,
87-
name: str = None,
87+
name: Optional[str] = None,
8888
*,
89-
ref2: DataReference = None,
90-
max_value: str = None,
89+
ref2: Optional[DataReference] = None,
90+
max_value: Optional[str] = None,
9191
):
9292
self.format = get_format_from_column_type(column_type)
9393
self.use_upper_bound_reference = use_upper_bound_reference
@@ -132,7 +132,7 @@ def __init__(
132132
min_fraction: float,
133133
lower_bound: str,
134134
upper_bound: str,
135-
name: str = None,
135+
name: Optional[str] = None,
136136
):
137137
super().__init__(ref, ref_value=min_fraction, name=name)
138138
self.lower_bound = lower_bound

src/datajudge/constraints/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def __init__(
1313
ref: DataReference,
1414
aggregation_column: str,
1515
start_value: int = 0,
16-
name: str = None,
16+
name: Optional[str] = None,
1717
*,
1818
tolerance: float = 0,
19-
ref2: DataReference = None,
19+
ref2: Optional[DataReference] = None,
2020
):
2121
super().__init__(ref, ref2=ref2, ref_value=object(), name=name)
2222
self.aggregation_column = aggregation_column

src/datajudge/constraints/interval.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(
1818
start_columns: List[str],
1919
end_columns: List[str],
2020
max_relative_n_violations: float,
21-
name: str = None,
21+
name: Optional[str] = None,
2222
):
2323
super().__init__(ref, ref_value=object(), name=name)
2424
self.key_columns = key_columns
@@ -56,7 +56,9 @@ def retrieve(
5656
sample_selection, n_violations_selection = self.select(engine, ref)
5757
with engine.connect() as connection:
5858
self.sample = connection.execute(sample_selection).first()
59-
n_violation_keys = connection.execute(n_violations_selection).scalar()
59+
n_violation_keys = int(
60+
str(connection.execute(n_violations_selection).scalar())
61+
)
6062

6163
selections = [*n_keys_selections, sample_selection, n_violations_selection]
6264
return (n_violation_keys, n_distinct_key_values), selections
@@ -97,7 +99,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
9799
return sample_selection, n_violations_selection
98100

99101
@abc.abstractmethod
100-
def compare(self, engine: sa.engine.Engine, ref: DataReference):
102+
def compare(self, factual: Any, target: Any):
101103
pass
102104

103105

src/datajudge/constraints/miscs.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
class PrimaryKeyDefinition(Constraint):
12-
def __init__(self, ref, primary_keys: List[str], name: str = None):
12+
def __init__(self, ref, primary_keys: List[str], name: Optional[str] = None):
1313
super().__init__(ref, ref_value=set(primary_keys), name=name)
1414

1515
def retrieve(
@@ -55,7 +55,7 @@ def __init__(
5555
max_duplicate_fraction: float = 0,
5656
max_absolute_n_duplicates: int = 0,
5757
infer_pk_columns: bool = False,
58-
name: str = None,
58+
name: Optional[str] = None,
5959
):
6060
if max_duplicate_fraction != 0 and max_absolute_n_duplicates != 0:
6161
raise ValueError(
@@ -148,10 +148,10 @@ def __init__(
148148
self,
149149
ref,
150150
*,
151-
ref2: DataReference = None,
152-
max_null_fraction: float = None,
151+
ref2: Optional[DataReference] = None,
152+
max_null_fraction: Optional[float] = None,
153153
max_relative_deviation: float = 0,
154-
name: str = None,
154+
name: Optional[str] = None,
155155
):
156156
super().__init__(ref, ref2=ref2, ref_value=max_null_fraction, name=name)
157157
if max_null_fraction is not None and not (0 <= max_null_fraction <= 1):

src/datajudge/constraints/nrows.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Tuple
2+
from typing import Optional, Tuple
33

44
import sqlalchemy as sa
55

@@ -11,7 +11,12 @@
1111

1212
class NRows(Constraint, abc.ABC):
1313
def __init__(
14-
self, ref, *, ref2: DataReference = None, n_rows: int = None, name: str = None
14+
self,
15+
ref,
16+
*,
17+
ref2: Optional[DataReference] = None,
18+
n_rows: Optional[int] = None,
19+
name: Optional[str] = None,
1520
):
1621
super().__init__(ref, ref2=ref2, ref_value=n_rows, name=name)
1722

@@ -79,7 +84,7 @@ def __init__(
7984
ref: DataReference,
8085
ref2: DataReference,
8186
max_relative_loss_getter: ToleranceGetter,
82-
name: str = None,
87+
name: Optional[str] = None,
8388
):
8489
super().__init__(ref, ref2=ref2, name=name)
8590
self.max_relative_loss_getter = max_relative_loss_getter
@@ -110,7 +115,7 @@ def __init__(
110115
ref: DataReference,
111116
ref2: DataReference,
112117
max_relative_gain_getter: ToleranceGetter,
113-
name: str = None,
118+
name: Optional[str] = None,
114119
):
115120
super().__init__(ref, ref2=ref2, name=name)
116121
self.max_relative_gain_getter = max_relative_gain_getter
@@ -141,7 +146,7 @@ def __init__(
141146
ref: DataReference,
142147
ref2: DataReference,
143148
min_relative_gain_getter: ToleranceGetter,
144-
name: str = None,
149+
name: Optional[str] = None,
145150
):
146151
super().__init__(ref, ref2=ref2, name=name)
147152
self.min_relative_gain_getter = min_relative_gain_getter

src/datajudge/constraints/numeric.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ class NumericMin(Constraint):
1212
def __init__(
1313
self,
1414
ref: DataReference,
15-
name: str = None,
15+
name: Optional[str] = None,
1616
*,
17-
ref2: DataReference = None,
18-
min_value: float = None,
17+
ref2: Optional[DataReference] = None,
18+
min_value: Optional[float] = None,
1919
):
2020
super().__init__(ref, ref2=ref2, ref_value=min_value, name=name)
2121

@@ -45,10 +45,10 @@ class NumericMax(Constraint):
4545
def __init__(
4646
self,
4747
ref: DataReference,
48-
name: str = None,
48+
name: Optional[str] = None,
4949
*,
50-
ref2: DataReference = None,
51-
max_value: float = None,
50+
ref2: Optional[DataReference] = None,
51+
max_value: Optional[float] = None,
5252
):
5353
super().__init__(
5454
ref,
@@ -86,7 +86,7 @@ def __init__(
8686
min_fraction: float,
8787
lower_bound: float,
8888
upper_bound: float,
89-
name: str = None,
89+
name: Optional[str] = None,
9090
):
9191
super().__init__(ref, ref_value=min_fraction, name=name)
9292
self.lower_bound = lower_bound
@@ -122,10 +122,10 @@ def __init__(
122122
self,
123123
ref: DataReference,
124124
max_absolute_deviation: float,
125-
name: str = None,
125+
name: Optional[str] = None,
126126
*,
127-
ref2: DataReference = None,
128-
mean_value: float = None,
127+
ref2: Optional[DataReference] = None,
128+
mean_value: Optional[float] = None,
129129
):
130130
super().__init__(
131131
ref,
@@ -170,8 +170,8 @@ def __init__(
170170
max_relative_deviation: Optional[float] = None,
171171
name: Optional[str] = None,
172172
*,
173-
ref2: DataReference = None,
174-
expected_percentile: float = None,
173+
ref2: Optional[DataReference] = None,
174+
expected_percentile: Optional[float] = None,
175175
):
176176
super().__init__(
177177
ref,

src/datajudge/constraints/row.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
ref: DataReference,
1616
ref2: DataReference,
1717
max_missing_fraction_getter: ToleranceGetter,
18-
name: str = None,
18+
name: Optional[str] = None,
1919
):
2020
super().__init__(ref, ref2=ref2, name=name)
2121
self.max_missing_fraction_getter = max_missing_fraction_getter
@@ -62,6 +62,8 @@ def compare(
6262
result = missing_fraction <= self.max_missing_fraction
6363
if result:
6464
return result, None
65+
if self.ref2 is None:
66+
raise ValueError("RowEquality constraint requires ref2.")
6567
if n_rows_missing_left > 0:
6668
sample_string = format_sample(self.ref1_minus_ref2_sample, self.ref2)
6769
else:
@@ -139,6 +141,8 @@ def compare(
139141
result = missing_fraction <= self.max_missing_fraction
140142
if result:
141143
return result, None
144+
if self.ref2 is None:
145+
raise ValueError("RowSuperset constraint requires ref2.")
142146
sample_string = format_sample(self.ref2_minus_ref1_sample, self.ref2)
143147
assertion_message = (
144148
f"{missing_fraction} > "
@@ -161,7 +165,7 @@ def __init__(
161165
comparison_columns1: List[str],
162166
comparison_columns2: List[str],
163167
max_missing_fraction_getter: ToleranceGetter,
164-
name: str = None,
168+
name: Optional[str] = None,
165169
):
166170
super().__init__(
167171
ref,

src/datajudge/constraints/stats.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
ref: DataReference,
1616
ref2: DataReference,
1717
significance_level: float = 0.05,
18-
name: str = None,
18+
name: Optional[str] = None,
1919
):
2020
self.significance_level = significance_level
2121
super().__init__(ref, ref2=ref2, name=name)
@@ -97,6 +97,8 @@ def calculate_statistic(
9797
return d_statistic, p_value, n_samples, m_samples, selections
9898

9999
def test(self, engine: sa.engine.Engine) -> TestResult:
100+
if self.ref2 is None:
101+
raise ValueError("KolmogorovSmirnov2Sample requires ref2.")
100102
(
101103
d_statistic,
102104
p_value,

0 commit comments

Comments
 (0)