diff --git a/src/datajudge/constraints/date.py b/src/datajudge/constraints/date.py index e9de5fd0..0d337bfe 100644 --- a/src/datajudge/constraints/date.py +++ b/src/datajudge/constraints/date.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import datetime as dt -from typing import Any, Optional, Tuple, Union +from typing import Any, Union import sqlalchemy as sa @@ -38,15 +40,15 @@ def __init__( ref: DataReference, use_lower_bound_reference: bool, column_type: str, - name: Optional[str] = None, + name: str | None = None, cache_size=None, *, - ref2: Optional[DataReference] = None, - min_value: Optional[str] = None, + ref2: DataReference | None = None, + min_value: str | None = None, ): self.format = get_format_from_column_type(column_type) self.use_lower_bound_reference = use_lower_bound_reference - min_date: Optional[dt.date] = None + min_date: dt.date | None = None if min_value is not None: min_date = dt.datetime.strptime(min_value, INPUT_DATE_FORMAT).date() super().__init__( @@ -59,11 +61,11 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[dt.date, OptionalSelections]: + ) -> tuple[dt.date, OptionalSelections]: result, selections = db_access.get_min(engine, ref) return convert_to_date(result, self.format), selections - def compare(self, min_factual: dt.date, min_target: dt.date) -> Tuple[bool, str]: + def compare(self, min_factual: dt.date, min_target: dt.date) -> tuple[bool, str]: if min_target is None: return TestResult(True, "") if min_factual is None: @@ -91,15 +93,15 @@ def __init__( ref: DataReference, use_upper_bound_reference: bool, column_type: str, - name: Optional[str] = None, + name: str | None = None, cache_size=None, *, - ref2: Optional[DataReference] = None, - max_value: Optional[str] = None, + ref2: DataReference | None = None, + max_value: str | None = None, ): self.format = get_format_from_column_type(column_type) self.use_upper_bound_reference = use_upper_bound_reference - max_date: Optional[dt.date] = None + max_date: dt.date | None = None if max_value is not None: max_date = dt.datetime.strptime(max_value, INPUT_DATE_FORMAT).date() super().__init__( @@ -112,11 +114,11 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[dt.date, OptionalSelections]: + ) -> tuple[dt.date, OptionalSelections]: value, selections = db_access.get_max(engine, ref) return convert_to_date(value, self.format), selections - def compare(self, max_factual: dt.date, max_target: dt.date) -> Tuple[bool, str]: + def compare(self, max_factual: dt.date, max_target: dt.date) -> tuple[bool, str]: if max_factual is None: return True, None if max_target is None: @@ -146,7 +148,7 @@ def __init__( min_fraction: float, lower_bound: str, upper_bound: str, - name: Optional[str] = None, + name: str | None = None, cache_size=None, ): super().__init__(ref, ref_value=min_fraction, name=name, cache_size=cache_size) @@ -155,14 +157,14 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[float, OptionalSelections]: + ) -> tuple[float | None, OptionalSelections]: return db_access.get_fraction_between( engine, ref, self.lower_bound, self.upper_bound ) def compare( self, fraction_factual: float, fraction_target: float - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: assertion_text = ( f"{self.ref} has {fraction_factual} < " f"{fraction_target} of values between {self.lower_bound} and " @@ -175,7 +177,7 @@ def compare( class DateNoOverlap(NoOverlapConstraint): _DIMENSIONS = 1 - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: return TestResult.success() @@ -193,7 +195,7 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: class DateNoOverlap2d(NoOverlapConstraint): _DIMENSIONS = 2 - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: return TestResult.success() @@ -225,7 +227,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference): # executing it, one would want to list this selection here as well. return sample_selection, n_violations_selection - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: return TestResult.success() diff --git a/src/datajudge/constraints/numeric.py b/src/datajudge/constraints/numeric.py index d0b9bc35..65551e19 100644 --- a/src/datajudge/constraints/numeric.py +++ b/src/datajudge/constraints/numeric.py @@ -1,4 +1,6 @@ -from typing import Any, Optional, Tuple +from __future__ import annotations + +from typing import Any import sqlalchemy as sa @@ -12,11 +14,11 @@ class NumericMin(Constraint): def __init__( self, ref: DataReference, - name: Optional[str] = None, + name: str | None = None, cache_size=None, *, - ref2: Optional[DataReference] = None, - min_value: Optional[float] = None, + ref2: DataReference | None = None, + min_value: float | None = None, ): super().__init__( ref, @@ -28,12 +30,10 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[float, OptionalSelections]: + ) -> tuple[float, OptionalSelections]: return db_access.get_min(engine, ref) - def compare( - self, min_factual: float, min_target: float - ) -> Tuple[bool, Optional[str]]: + def compare(self, min_factual: float, min_target: float) -> tuple[bool, str | None]: if min_target is None: return True, None if min_factual is None: @@ -52,11 +52,11 @@ class NumericMax(Constraint): def __init__( self, ref: DataReference, - name: Optional[str] = None, + name: str | None = None, cache_size=None, *, - ref2: Optional[DataReference] = None, - max_value: Optional[float] = None, + ref2: DataReference | None = None, + max_value: float | None = None, ): super().__init__( ref, @@ -68,12 +68,10 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[float, OptionalSelections]: + ) -> tuple[float, OptionalSelections]: return db_access.get_max(engine, ref) - def compare( - self, max_factual: float, max_target: float - ) -> Tuple[bool, Optional[str]]: + def compare(self, max_factual: float, max_target: float) -> tuple[bool, str | None]: if max_factual is None: return True, None if max_target is None: @@ -95,7 +93,7 @@ def __init__( min_fraction: float, lower_bound: float, upper_bound: float, - name: Optional[str] = None, + name: str | None = None, cache_size=None, ): super().__init__(ref, ref_value=min_fraction, name=name, cache_size=cache_size) @@ -104,7 +102,7 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[float, OptionalSelections]: + ) -> tuple[float | None, OptionalSelections]: return db_access.get_fraction_between( engine, ref, @@ -114,7 +112,7 @@ def retrieve( def compare( self, fraction_factual: float, fraction_target: float - ) -> Tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: if fraction_factual is None: return True, "Empty selection." assertion_text = ( @@ -132,11 +130,11 @@ def __init__( self, ref: DataReference, max_absolute_deviation: float, - name: Optional[str] = None, + name: str | None = None, cache_size=None, *, - ref2: Optional[DataReference] = None, - mean_value: Optional[float] = None, + ref2: DataReference | None = None, + mean_value: float | None = None, ): super().__init__( ref, @@ -149,7 +147,7 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[float, OptionalSelections]: + ) -> tuple[float, OptionalSelections]: result, selections = db_access.get_mean(engine, ref) return result, selections @@ -178,13 +176,13 @@ def __init__( self, ref: DataReference, percentage: float, - max_absolute_deviation: Optional[float] = None, - max_relative_deviation: Optional[float] = None, - name: Optional[str] = None, + max_absolute_deviation: float | None = None, + max_relative_deviation: float | None = None, + name: str | None = None, cache_size=None, *, - ref2: Optional[DataReference] = None, - expected_percentile: Optional[float] = None, + ref2: DataReference | None = None, + expected_percentile: float | None = None, ): super().__init__( ref, @@ -216,13 +214,13 @@ def __init__( def retrieve( self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[float, OptionalSelections]: + ) -> tuple[float, OptionalSelections]: result, selections = db_access.get_percentile(engine, ref, self.percentage) return result, selections def compare( self, percentile_factual: float, percentile_target: float - ) -> Tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: abs_diff = abs(percentile_factual - percentile_target) if ( self.max_absolute_deviation is not None @@ -269,7 +267,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference): # executing it, one would want to list this selection here as well. return sample_selection, n_violations_selection - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: return TestResult.success() @@ -287,7 +285,7 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: class NumericNoOverlap(NoOverlapConstraint): _DIMENSIONS = 1 - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: return TestResult.success() diff --git a/src/datajudge/constraints/row.py b/src/datajudge/constraints/row.py index 8b0311e8..6e3e113a 100644 --- a/src/datajudge/constraints/row.py +++ b/src/datajudge/constraints/row.py @@ -22,6 +22,8 @@ def __init__( self.max_missing_fraction_getter = max_missing_fraction_getter def test(self, engine: sa.engine.Engine) -> TestResult: + if self.ref is None or self.ref2 is None: + raise ValueError() if db_access.is_impala(engine): raise NotImplementedError("Currently not implemented for impala.") self.max_missing_fraction = self.max_missing_fraction_getter(engine) @@ -36,6 +38,8 @@ def test(self, engine: sa.engine.Engine) -> TestResult: class RowEquality(Row): def get_factual_value(self, engine: sa.engine.Engine) -> Tuple[int, int]: + if self.ref is None or self.ref2 is None: + raise ValueError() n_rows_missing_left, selections_left = db_access.get_row_difference_count( engine, self.ref, self.ref2 ) @@ -46,6 +50,8 @@ def get_factual_value(self, engine: sa.engine.Engine) -> Tuple[int, int]: return n_rows_missing_left, n_rows_missing_right def get_target_value(self, engine: sa.engine.Engine) -> int: + if self.ref is None or self.ref2 is None: + raise ValueError() n_rows_total, selections = db_access.get_unique_count_union( engine, self.ref, self.ref2 ) @@ -80,6 +86,8 @@ def compare( class RowSubset(Row): @lru_cache(maxsize=None) def get_factual_value(self, engine: sa.engine.Engine) -> int: + if self.ref is None or self.ref2 is None: + raise ValueError() n_rows_missing, selections = db_access.get_row_difference_count( engine, self.ref, @@ -118,6 +126,8 @@ def compare( class RowSuperset(Row): def get_factual_value(self, engine: sa.engine.Engine) -> int: + if self.ref is None or self.ref2 is None: + raise ValueError() n_rows_missing, selections = db_access.get_row_difference_count( engine, self.ref2, self.ref ) @@ -125,6 +135,8 @@ def get_factual_value(self, engine: sa.engine.Engine) -> int: return n_rows_missing def get_target_value(self, engine: sa.engine.Engine) -> int: + if self.ref is None or self.ref2 is None: + raise ValueError() n_rows_total, selections = db_access.get_unique_count(engine, self.ref2) self.target_selections = selections return n_rows_total @@ -180,6 +192,8 @@ def __init__( ) def test(self, engine: sa.engine.Engine) -> TestResult: + if self.ref is None or self.ref2 is None: + raise ValueError() missing_fraction, n_rows_match, selections = db_access.get_row_mismatch( engine, self.ref, self.ref2, self.match_and_compare ) diff --git a/src/datajudge/db_access.py b/src/datajudge/db_access.py index 7a1bd8e5..1310ccaa 100644 --- a/src/datajudge/db_access.py +++ b/src/datajudge/db_access.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from collections import Counter from dataclasses import dataclass -from typing import Any, Callable, Sequence, final, overload +from typing import Any, Callable, Iterator, Sequence, final, overload import sqlalchemy as sa from sqlalchemy.sql import selectable @@ -37,11 +37,13 @@ def is_db2(engine: sa.engine.Engine) -> bool: return engine.name == "ibm_db_sa" -def get_table_columns(table, column_names): +def get_table_columns( + table: sa.Table | sa.Subquery, column_names: Sequence[str] +) -> list[sa.ColumnElement]: return [table.c[column_name] for column_name in column_names] -def apply_patches(engine: sa.engine.Engine): +def apply_patches(engine: sa.engine.Engine) -> None: """ Apply patches to e.g. specific dialect not implemented by sqlalchemy """ @@ -142,11 +144,18 @@ def __post_init__(self): f"obtained {self.reduction_operator}." ) - def _is_atomic(self): + def _is_atomic(self) -> bool: return self.raw_string is not None - def __str__(self): + def __str__(self) -> str: if self._is_atomic(): + if self.raw_string is None: + raise ValueError( + "Condition can either be instantiated atomically, with " + "the raw_query parameter, or in a composite fashion, with " + "the conditions parameter. " + "Exactly one of them needs to be provided, yet none is." + ) return self.raw_string if not self.conditions: raise ValueError("This should never happen thanks to __post__init.") @@ -154,7 +163,7 @@ def __str__(self): f"({condition})" for condition in self.conditions ) - def snowflake_str(self): + def snowflake_str(self) -> str: # Temporary method - should be removed as soon as snowflake-sqlalchemy # bug is fixed. return str(self) @@ -167,13 +176,13 @@ class MatchAndCompare: comparison_columns1: Sequence[str] comparison_columns2: Sequence[str] - def _get_matching_columns(self): + def _get_matching_columns(self) -> Iterator[tuple[str, str]]: return zip(self.matching_columns1, self.matching_columns2) - def _get_comparison_columns(self): + def _get_comparison_columns(self) -> Iterator[tuple[str, str]]: return zip(self.comparison_columns1, self.comparison_columns2) - def __str__(self): + def __str__(self) -> str: return ( f"Matched on {self.matching_columns1} and " f"{self.matching_columns2}. Compared on " @@ -181,7 +190,7 @@ def __str__(self): f"{self.comparison_columns2}." ) - def get_matching_string(self, table_variable1, table_variable2): + def get_matching_string(self, table_variable1: str, table_variable2: str) -> str: return " AND ".join( [ f"{table_variable1}.{column1} = {table_variable2}.{column2}" @@ -189,7 +198,7 @@ def get_matching_string(self, table_variable1, table_variable2): ] ) - def get_comparison_string(self, table_variable1, table_variable2): + def get_comparison_string(self, table_variable1: str, table_variable2: str) -> str: return " AND ".join( [ ( @@ -214,7 +223,7 @@ def get_clause(self, engine: sa.engine.Engine) -> FromClause: @functools.lru_cache(maxsize=1) -def get_metadata(): +def get_metadata() -> sa.MetaData: return sa.MetaData() @@ -309,7 +318,7 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(data_source={self.data_source!r}, columns={self.columns!r}, condition={self.condition!r})" - def get_selection(self, engine: sa.engine.Engine): + def get_selection(self, engine: sa.engine.Engine) -> sa.Select: clause = self.data_source.get_clause(engine) if self.columns: column_names = self.get_columns(engine) @@ -334,7 +343,7 @@ def get_selection(self, engine: sa.engine.Engine): selection = selection.with_hint(clause, "WITH (NOLOCK)") return selection - def get_column(self, engine): + def get_column(self, engine: sa.engine.Engine) -> str: """Fetch the only relevant column of a DataReference.""" if self.columns is None: raise ValueError( @@ -351,7 +360,7 @@ def get_column(self, engine): ) return columns[0] - def get_columns(self, engine) -> list[str] | None: + def get_columns(self, engine: sa.engine.Engine) -> list[str] | None: """Fetch all relevant columns of a DataReference.""" if self.columns is None: return None @@ -359,29 +368,31 @@ def get_columns(self, engine) -> list[str] | None: return lowercase_column_names(self.columns) return self.columns - def get_columns_or_pk_columns(self, engine): + def get_columns_or_pk_columns(self, engine: sa.engine.Engine) -> list[str] | None: return ( self.columns if self.columns is not None - else get_primary_keys(engine, self.data_source) + else get_primary_keys(engine, self)[0] ) - def get_column_selection_string(self): + def get_column_selection_string(self) -> str: if self.columns is None: return " * " return ", ".join(map(lambda x: f"'{x}'", self.columns)) - def get_clause_string(self, *, return_where=True): + def get_clause_string(self, *, return_where: bool = True) -> str: where_string = "WHERE " if return_where else "" return "" if self.condition is None else where_string + str(self.condition) - def __str__(self): + def __str__(self) -> str: if self.columns is None: return str(self.data_source) return f"{self.data_source}'s column(s) {self.get_column_selection_string()}" -def merge_conditions(condition1, condition2): +def merge_conditions( + condition1: Condition | None, condition2: Condition | None +) -> Condition | None: if condition1 and condition2 is None: return None if condition1 is None: @@ -391,7 +402,9 @@ def merge_conditions(condition1, condition2): return Condition(conditions=[condition1, condition2], reduction_operator="and") -def get_date_span(engine, ref, date_column_name): +def get_date_span( + engine: sa.engine.Engine, ref: DataReference, date_column_name: str +) -> tuple[float, list[sa.Select]]: if is_snowflake(engine): date_column_name = lowercase_column_names(date_column_name) subquery = ref.get_selection(engine).alias() @@ -452,6 +465,8 @@ def get_date_span(engine, ref, date_column_name): ) date_span = engine.connect().execute(selection).scalar() + if date_span is None: + raise ValueError("Date span could not be fetched.") if date_span < 0: raise ValueError( f"Date span has negative value: {date_span}. It must be positive." @@ -463,7 +478,13 @@ def get_date_span(engine, ref, date_column_name): return float(date_span), [selection] -def get_date_growth_rate(engine, ref, ref2, date_column, date_column2): +def get_date_growth_rate( + engine: sa.engine.Engine, + ref: DataReference, + ref2: DataReference, + date_column: str, + date_column2: str, +) -> tuple[float, list[sa.Select]]: date_span, selections = get_date_span(engine, ref, date_column) date_span2, selections2 = get_date_span(engine, ref2, date_column2) if date_span2 == 0: @@ -631,13 +652,13 @@ def get_interval_overlaps_nd( def _not_in_interval_condition( - main_table: sa.Table, - helper_table: sa.Table, + main_table: sa.Table | sa.Subquery, + helper_table: sa.Table | sa.Subquery, date_column: str, key_columns: list[str], start_column: str, end_column: str, -): +) -> sa.ColumnElement: return sa.not_( sa.exists( sa.select(helper_table).where( @@ -664,7 +685,7 @@ def _get_interval_gaps( make_gap_condition: Callable[ [sa.Engine, sa.Subquery, sa.Subquery, str, str, float], sa.ColumnElement[bool] ], -): +) -> tuple[sa.Select, sa.Select]: if is_snowflake(engine): if key_columns: key_columns = lowercase_column_names(key_columns) @@ -820,7 +841,7 @@ def get_date_gaps( start_column: str, end_column: str, legitimate_gap_size: float, -): +) -> tuple[sa.Select, sa.Select]: return _get_interval_gaps( engine, ref, @@ -853,7 +874,7 @@ def get_numeric_gaps( start_column: str, end_column: str, legitimate_gap_size: float = 0, -): +) -> tuple[sa.Select, sa.Select]: return _get_interval_gaps( engine, ref, @@ -869,7 +890,7 @@ def get_functional_dependency_violations( engine: sa.engine.Engine, ref: DataReference, key_columns: list[str], -): +) -> tuple[Any, list[sa.Select]]: selection = ref.get_selection(engine) uniques = selection.distinct().cte() @@ -894,19 +915,21 @@ def get_functional_dependency_violations( def get_row_count( - engine, ref, row_limit: int | None = None + engine: sa.engine.Engine, ref: DataReference, row_limit: int | None = None ) -> tuple[int, list[sa.Select]]: """Return the number of rows for a `DataReference`. If `row_limit` is given, the number of rows is capped at the limit. """ - subquery = ref.get_selection(engine) + selection = ref.get_selection(engine) if row_limit: - subquery = subquery.limit(row_limit) - subquery = subquery.alias() - selection = sa.select(sa.cast(sa.func.count(), sa.BigInteger)).select_from(subquery) - result = int(str(engine.connect().execute(selection).scalar())) - return result, [selection] + selection = selection.limit(row_limit) + subquery = selection.alias() + final_selection = sa.select(sa.cast(sa.func.count(), sa.BigInteger)).select_from( + subquery + ) + result = int(str(engine.connect().execute(final_selection).scalar())) + return result, [final_selection] def get_column( @@ -914,7 +937,7 @@ def get_column( ref: DataReference, *, aggregate_operator: Callable | None = None, -): +) -> tuple[Any, list[sa.Select]]: """ Queries the database for the values of the relevant column (as returned by `get_column(...)`). If an aggregation operation is passed, the results are aggregated accordingly @@ -936,17 +959,23 @@ def get_column( return result, [selection] -def get_min(engine, ref): +def get_min( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[Any, list[sa.Select]]: column_operator = sa.func.min return get_column(engine, ref, aggregate_operator=column_operator) -def get_max(engine, ref): +def get_max( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[Any, list[sa.Select]]: column_operator = sa.func.max return get_column(engine, ref, aggregate_operator=column_operator) -def get_mean(engine, ref): +def get_mean( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[Any, list[sa.Select]]: def column_operator(column): if is_impala(engine): return sa.func.avg(column) @@ -955,7 +984,9 @@ def column_operator(column): return get_column(engine, ref, aggregate_operator=column_operator) -def get_percentile(engine, ref, percentage): +def get_percentile( + engine: sa.engine.Engine, ref: DataReference, percentage: float +) -> tuple[float, list[sa.Select]]: row_count = "dj_row_count" row_num = "dj_row_num" column_name = ref.get_column(engine) @@ -993,25 +1024,37 @@ def get_percentile(engine, ref, percentage): percentile_selection = sa.select(counting_subquery.c[column_name]).where( counting_subquery.c[row_num] == argmin_selection.scalar_subquery() ) - result = engine.connect().execute(percentile_selection).scalar() + intermediate_result = engine.connect().execute(percentile_selection).scalar() + if intermediate_result is None: + raise ValueError("Percentile selection could not be fetched.") + result = float(intermediate_result) return result, [percentile_selection] -def get_min_length(engine, ref): +def get_min_length( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[int, list[sa.Select]]: def column_operator(column): return sa.func.min(sa.func.length(column)) return get_column(engine, ref, aggregate_operator=column_operator) -def get_max_length(engine, ref): +def get_max_length( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[int, list[sa.Select]]: def column_operator(column): return sa.func.max(sa.func.length(column)) return get_column(engine, ref, aggregate_operator=column_operator) -def get_fraction_between(engine, ref, lower_bound, upper_bound): +def get_fraction_between( + engine: sa.engine.Engine, + ref: DataReference, + lower_bound: str | float, + upper_bound: str | float, +) -> tuple[float | None, list[sa.Select]]: column = ref.get_column(engine) new_condition = Condition( conditions=[ @@ -1027,7 +1070,11 @@ def get_fraction_between(engine, ref, lower_bound, upper_bound): n_all, selections_all = get_row_count(engine, ref) n_filtered, selections_filtered = get_row_count(engine, new_ref) selections = [*selections_all, *selections_filtered] - return (n_filtered / n_all) if n_all > 0 else None, selections + if n_all is None or n_all == 0: + return (None, selections) + if n_filtered is None: + return (0.0, selections) + return n_filtered / n_all, selections def get_uniques( @@ -1035,10 +1082,10 @@ def get_uniques( ) -> tuple[Counter, list[sa.Select]]: if not ref.get_columns(engine): return Counter({}), [] - selection = ref.get_selection(engine).alias() + subquery = ref.get_selection(engine).alias() if (column_names := ref.get_columns(engine)) is None: raise ValueError("Need columns for get_uniques.") - columns = [selection.c[column_name] for column_name in column_names] + columns = [subquery.c[column_name] for column_name in column_names] selection = sa.select(*columns, sa.func.count()).group_by(*columns) def _scalar_accessor(row): @@ -1061,24 +1108,36 @@ def _tuple_accessor(row): return result, [selection] -def get_unique_count(engine, ref) -> tuple[int, list[sa.Select]]: +def get_unique_count( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[int, list[sa.Select]]: selection = ref.get_selection(engine) subquery = selection.distinct().alias() selection = sa.select(sa.func.count()).select_from(subquery) - result = int(engine.connect().execute(selection).scalar()) + intermediate_result = engine.connect().execute(selection).scalar() + if intermediate_result is None: + raise ValueError("Unique count could not be fetched.") + result = int(intermediate_result) return result, [selection] -def get_unique_count_union(engine, ref, ref2): +def get_unique_count_union( + engine: sa.engine.Engine, ref: DataReference, ref2: DataReference +) -> tuple[int, list[sa.Select]]: selection1 = ref.get_selection(engine) selection2 = ref2.get_selection(engine) subquery = sa.sql.union(selection1, selection2).alias().select().distinct().alias() selection = sa.select(sa.func.count()).select_from(subquery) - result = engine.connect().execute(selection).scalar() + intermediate_result = engine.connect().execute(selection).scalar() + if intermediate_result is None: + raise ValueError("Unique count could not be fetched.") + result = int(intermediate_result) return result, [selection] -def get_missing_fraction(engine, ref): +def get_missing_fraction( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[float, list[sa.Select]]: selection = ref.get_selection(engine).subquery() n_rows_total_selection = sa.select(sa.func.count()).select_from(selection) n_rows_missing_selection = ( @@ -1090,29 +1149,39 @@ def get_missing_fraction(engine, ref): n_rows_total = connection.execute(n_rows_total_selection).scalar() n_rows_missing = connection.execute(n_rows_missing_selection).scalar() + if n_rows_total is None or n_rows_missing is None: + return (0, [n_rows_total_selection, n_rows_missing_selection]) return ( n_rows_missing / n_rows_total, [n_rows_total_selection, n_rows_missing_selection], ) -def get_column_names(engine, ref): +def get_column_names( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[list[str], None]: table = ref.data_source.get_clause(engine) return [column.name for column in table.columns], None -def get_column_type(engine, ref): +def get_column_type(engine: sa.engine.Engine, ref: DataReference) -> tuple[Any, None]: table = ref.get_selection(engine).alias() column_type = next(iter(table.columns)).type return column_type, None -def get_primary_keys(engine, ref): +def get_primary_keys( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[list[str], None]: table = ref.data_source.get_clause(engine) - return [column.name for column in table.primary_key.columns], None + # Kevin, 25/02/04 + # Mypy complains about the following for a reason I can't follow. + return [column.name for column in table.primary_key.columns], None # type: ignore -def get_row_difference_sample(engine, ref, ref2): +def get_row_difference_sample( + engine: sa.engine.Engine, ref: DataReference, ref2: DataReference +) -> tuple[Any, list[sa.Select]]: selection1 = ref.get_selection(engine) selection2 = ref2.get_selection(engine) selection = sa.sql.except_(selection1, selection2).alias().select() @@ -1120,18 +1189,28 @@ def get_row_difference_sample(engine, ref, ref2): return result, [selection] -def get_row_difference_count(engine, ref, ref2): +def get_row_difference_count( + engine: sa.engine.Engine, ref: DataReference, ref2: DataReference +) -> tuple[int, list[sa.Select]]: selection1 = ref.get_selection(engine) selection2 = ref2.get_selection(engine) subquery = ( sa.sql.except_(selection1, selection2).alias().select().distinct().alias() ) selection = sa.select(sa.func.count()).select_from(subquery) - result = engine.connect().execute(selection).scalar() + result_intermediate = engine.connect().execute(selection).scalar() + if result_intermediate is None: + raise ValueError("Could not get the row difference count.") + result = int(result_intermediate) return result, [selection] -def get_row_mismatch(engine, ref, ref2, match_and_compare): +def get_row_mismatch( + engine: sa.engine.Engine, + ref: DataReference, + ref2: DataReference, + match_and_compare: MatchAndCompare, +) -> tuple[float, int, list[sa.Select]]: subselection1 = ref.get_selection(engine).alias() subselection2 = ref2.get_selection(engine).alias() @@ -1169,12 +1248,18 @@ def get_row_mismatch(engine, ref, ref2, match_and_compare): selection_n_rows = sa.select(sa.func.count()).select_from( subselection1.join(subselection2, match) ) - result_mismatch = engine.connect().execute(selection_difference).scalar() - result_n_rows = engine.connect().execute(selection_n_rows).scalar() + result_mismatch_intermediate = ( + engine.connect().execute(selection_difference).scalar() + ) + result_n_rows_intermediate = engine.connect().execute(selection_n_rows).scalar() + if result_mismatch_intermediate is None or result_n_rows_intermediate is None: + raise ValueError("Could not fetch number of mismatches.") + result_mismatch = float(result_mismatch_intermediate) + result_n_rows = int(result_n_rows_intermediate) return result_mismatch, result_n_rows, [selection_difference, selection_n_rows] -def duplicates(subquery: sa.sql.selectable.Subquery) -> sa.sql.selectable.Select: +def duplicates(subquery: sa.sql.selectable.Subquery) -> sa.Select: aggregate_subquery = ( sa.select(subquery, sa.func.count().label("n_copies")) .select_from(subquery) @@ -1195,7 +1280,9 @@ def duplicates(subquery: sa.sql.selectable.Subquery) -> sa.sql.selectable.Select return duplicate_selection -def get_duplicate_sample(engine: sa.engine.Engine, ref: DataReference) -> tuple: +def get_duplicate_sample( + engine: sa.engine.Engine, ref: DataReference +) -> tuple[Any, list[sa.Select]]: initial_selection = ref.get_selection(engine).alias() duplicate_selection = duplicates(initial_selection) result = engine.connect().execute(duplicate_selection).first() @@ -1204,7 +1291,7 @@ def get_duplicate_sample(engine: sa.engine.Engine, ref: DataReference) -> tuple: def column_array_agg_query( engine: sa.engine.Engine, ref: DataReference, aggregation_column: str -): +) -> list[sa.Select]: clause = ref.data_source.get_clause(engine) if not (column_names := ref.get_columns(engine)): raise ValueError("There must be a column to group by") @@ -1216,7 +1303,7 @@ def column_array_agg_query( return [selection] -def snowflake_parse_variant_column(value: str): +def snowflake_parse_variant_column(value: str) -> dict: # Snowflake returns non-primitive columns such as arrays as JSON string, # but we want them in their deserialized form. return json.loads(value) @@ -1224,7 +1311,7 @@ def snowflake_parse_variant_column(value: str): def get_column_array_agg( engine: sa.engine.Engine, ref: DataReference, aggregation_column: str -): +) -> tuple[Any, list[sa.Select]]: selections = column_array_agg_query(engine, ref, aggregation_column) result: Sequence[sa.engine.row.Row[Any]] | list[tuple[Any, ...]] = ( engine.connect().execute(selections[0]).fetchall() @@ -1237,7 +1324,9 @@ def get_column_array_agg( return result, selections -def _cdf_selection(engine, ref: DataReference, cdf_label: str, value_label: str): +def _cdf_selection( + engine: sa.engine.Engine, ref: DataReference, cdf_label: str, value_label: str +) -> sa.Subquery: """Create an empirical cumulative distribution function values. Concretely, create a selection with values from ``value_label`` as well as @@ -1266,8 +1355,12 @@ def _cdf_selection(engine, ref: DataReference, cdf_label: str, value_label: str) def _cross_cdf_selection( - engine, ref1: DataReference, ref2: DataReference, cdf_label: str, value_label: str -): + engine: sa.engine.Engine, + ref1: DataReference, + ref2: DataReference, + cdf_label: str, + value_label: str, +) -> tuple[sa.Select, str, str]: """Create a cross cumulative distribution function selection given two samples. Concretely, both ``DataReference``s are expected to have specified a single relevant column. @@ -1354,7 +1447,7 @@ def get_ks_2sample( engine: sa.engine.Engine, ref1: DataReference, ref2: DataReference, -): +) -> tuple[float, list[sa.Select]]: """ Run the query for the two-sample Kolmogorov-Smirnov test and return the test statistic d. @@ -1379,15 +1472,24 @@ def get_ks_2sample( with engine.connect() as connection: d_statistic = connection.execute(final_selection).scalar() + if d_statistic is None: + raise ValueError("Could not compute d statistic.") + return d_statistic, [final_selection] -def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples): - subquery = ref.get_selection(engine) +def get_regex_violations( + engine: sa.engine.Engine, + ref: DataReference, + aggregated: bool, + regex: str, + n_counterexamples: int, +) -> tuple[tuple[int, Any], list[sa.Select]]: + original_selection = ref.get_selection(engine) column = ref.get_column(engine) if aggregated: - subquery = subquery.distinct() - subquery = subquery.subquery() + original_selection = original_selection.distinct() + subquery = original_selection.subquery() if is_impala(engine): violation_selection = sa.select(subquery.c[column]).where( sa.not_(sa.func.regexp_like(subquery.c[column], regex)) @@ -1416,6 +1518,8 @@ def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples): with engine.connect() as connection: n_violations_result = connection.execute(n_violations_selection).scalar() + if n_violations_result is None: + n_violations_result = 0 if counterexamples_selection is None: counterexamples = [] else: