Skip to content

Commit

Permalink
Annotate remaining function parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Feb 4, 2025
1 parent b1a47a6 commit db4f013
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions src/datajudge/db_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ def is_db2(engine: sa.engine.Engine) -> bool:
return engine.name == "ibm_db_sa"


def get_table_columns(
    table: sa.Table | sa.Subquery, column_names: Sequence[str]
) -> list:
    """Return the entries of ``table.c`` for ``column_names``, in the given order.

    Each name is looked up with ``table.c[name]``; whatever that lookup raises
    for an unknown name propagates to the caller.
    """
    return [table.c[name] for name in column_names]


def apply_patches(engine: sa.engine.Engine):
def apply_patches(engine: sa.engine.Engine) -> None:
"""
Apply patches to e.g. specific dialect not implemented by sqlalchemy
"""
Expand Down Expand Up @@ -142,19 +144,26 @@ def __post_init__(self):
f"obtained {self.reduction_operator}."
)

def _is_atomic(self) -> bool:
    """Tell whether ``raw_string`` is set (anything other than ``None``)."""
    raw = self.raw_string
    return raw is not None

def __str__(self) -> str:
    """Render the condition as a string.

    When ``raw_string`` is set it is returned verbatim. Otherwise every entry
    of ``conditions`` is rendered, wrapped in parentheses, and the pieces are
    joined with ``reduction_operator`` surrounded by single spaces.

    Raises:
        ValueError: If ``raw_string`` is ``None`` and ``conditions`` is empty
            or ``None``.
    """
    # Testing the attribute directly is the same check ``_is_atomic()``
    # performs. The former nested ``raw_string is None`` branch could never
    # run after that check (it was the line reported as uncovered), and its
    # message referred to a ``raw_query`` parameter that this code never reads.
    if self.raw_string is not None:
        return self.raw_string
    if not self.conditions:
        raise ValueError("This should never happen thanks to __post_init__.")
    return f" {self.reduction_operator} ".join(
        f"({condition})" for condition in self.conditions
    )

def snowflake_str(self) -> str:
    """Return exactly what ``str(self)`` returns.

    Temporary method: it should be removed as soon as the
    snowflake-sqlalchemy bug is fixed.
    """
    rendered = str(self)
    return rendered
Expand All @@ -173,23 +182,23 @@ def _get_matching_columns(self):
def _get_comparison_columns(self):
    """Pair ``comparison_columns1`` with ``comparison_columns2`` by position.

    Returns a lazy ``zip`` iterator, which ends with the shorter input.
    """
    first, second = self.comparison_columns1, self.comparison_columns2
    return zip(first, second)

def __str__(self) -> str:
    """Summarise which columns are matched on and which are compared on."""
    matched = (
        f"Matched on {self.matching_columns1} and "
        f"{self.matching_columns2}."
    )
    compared = (
        f"Compared on {self.comparison_columns1} and "
        f"{self.comparison_columns2}."
    )
    return f"{matched} {compared}"

def get_matching_string(self, table_variable1: str, table_variable2: str) -> str:
    """Build the equality predicate over the matching column pairs.

    Every pair yielded by ``_get_matching_columns()`` becomes
    ``<table_variable1>.<column1> = <table_variable2>.<column2>``; the pieces
    are joined with `` AND ``. No pairs yield an empty string.
    """
    equalities = (
        f"{table_variable1}.{left} = {table_variable2}.{right}"
        for left, right in self._get_matching_columns()
    )
    return " AND ".join(equalities)

def get_comparison_string(self, table_variable1, table_variable2):
def get_comparison_string(self, table_variable1: str, table_variable2: str) -> str:
return " AND ".join(
[
(
Expand All @@ -214,7 +223,7 @@ def get_clause(self, engine: sa.engine.Engine) -> FromClause:


@functools.lru_cache(maxsize=1)
def get_metadata() -> sa.MetaData:
    """Return the cached ``sa.MetaData`` object.

    Because of the ``lru_cache`` decorator, ``sa.MetaData()`` is constructed on
    the first call and that same object is handed back on later calls.
    """
    metadata = sa.MetaData()
    return metadata


Expand Down Expand Up @@ -334,7 +343,7 @@ def get_selection(self, engine: sa.engine.Engine):
selection = selection.with_hint(clause, "WITH (NOLOCK)")
return selection

def get_column(self, engine: sa.engine.Engine):
def get_column(self, engine: sa.engine.Engine) -> str:
"""Fetch the only relevant column of a DataReference."""
if self.columns is None:
raise ValueError(
Expand Down Expand Up @@ -379,7 +388,9 @@ def __str__(self):
return f"{self.data_source}'s column(s) {self.get_column_selection_string()}"


def merge_conditions(
    condition1: Condition | None, condition2: Condition | None
) -> Condition | None:
    """Combine two optional conditions into one.

    Returns ``None`` when both arguments are ``None``, the argument that is set
    when only one is, and otherwise a composite ``Condition`` that joins both
    with the ``"and"`` reduction operator.
    """
    # The previous guard, ``condition1 and condition2 is None``, parsed as
    # ``condition1 and (condition2 is None)``: a truthy ``condition1`` combined
    # with ``condition2=None`` returned ``None`` and silently dropped
    # ``condition1``. Explicit ``is None`` checks avoid relying on truthiness.
    if condition1 is None:
        # Also covers the case where both are None.
        return condition2
    if condition2 is None:
        return condition1
    return Condition(conditions=[condition1, condition2], reduction_operator="and")


def get_date_span(engine: sa.engine.Engine, ref: DataReference, date_column_name):
def get_date_span(engine: sa.engine.Engine, ref: DataReference, date_column_name: str):
if is_snowflake(engine):
date_column_name = lowercase_column_names(date_column_name)
subquery = ref.get_selection(engine).alias()
Expand Down Expand Up @@ -961,7 +972,7 @@ def column_operator(column):
return get_column(engine, ref, aggregate_operator=column_operator)


def get_percentile(engine: sa.engine.Engine, ref: DataReference, percentage):
def get_percentile(engine: sa.engine.Engine, ref: DataReference, percentage: float):
row_count = "dj_row_count"
row_num = "dj_row_num"
column_name = ref.get_column(engine)
Expand Down Expand Up @@ -1018,7 +1029,10 @@ def column_operator(column):


def get_fraction_between(
engine: sa.engine.Engine, ref: DataReference, lower_bound, upper_bound
engine: sa.engine.Engine,
ref: DataReference,
lower_bound: str | float,
upper_bound: str | float,
):
column = ref.get_column(engine)
new_condition = Condition(
Expand Down Expand Up @@ -1414,7 +1428,11 @@ def get_ks_2sample(


def get_regex_violations(
engine: sa.engine.Engine, ref: DataReference, aggregated, regex, n_counterexamples
engine: sa.engine.Engine,
ref: DataReference,
aggregated: bool,
regex: str,
n_counterexamples: int,
):
subquery = ref.get_selection(engine)
column = ref.get_column(engine)
Expand Down

0 comments on commit db4f013

Please sign in to comment.