Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added cohort constraint to valueset levels to address issue #304 #306

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 75 additions & 66 deletions icees_api/features/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,47 +38,37 @@ def get_digest(*args):
return c.digest()


def op_dict(k, v, table_=None):
    """Apply the qualifier ``v`` to a table column or to a value-set level.

    The function works in one of two modes, selected by ``table_``:

    * ``table_`` is given: ``k`` is looked up in ``table_.c`` and the result
      is an expression built from that column, for use in ``where``/``case``.
    * ``table_`` is ``None``: ``k`` is a value-set level such as ``10``,
      ``"<10"`` or ``">9"``.  The level is split with
      ``get_level_operator_and_value`` and its value is compared directly
      with the qualifier, so the result can be used as a filter condition.

    Every operand taken from ``v`` is passed through ``simplify_value``
    together with the qualifier's operator before it is compared.

    Args:
        k: Column name (table mode) or value-set level (level mode).
        v: Dict with an ``"operator"`` key.  The other keys depend on the
            operator: ``"value"`` for ``>``, ``<``, ``>=``, ``<=``, ``=`` and
            ``<>``; ``"value_a"`` and ``"value_b"`` for ``between``;
            ``"values"`` for ``in``.
        table_: Table whose column ``k`` is constrained, or ``None`` to
            evaluate the qualifier against the level ``k`` itself.

    Returns:
        The comparison produced by the selected operator.

    Raises:
        HTTPException: 400 if ``table_`` is given and has no column ``k``.
        KeyError: If ``v["operator"]`` is not one of the supported operators.
    """
    operator_ = v["operator"]
    level_mode = table_ is None

    if level_mode:
        level = get_level_operator_and_value(k)
        value = simplify_value(level["value"], operator_)
        if isinstance(value, int):
            # A level written with < or > is an open bound.  Shift it by one
            # to the inclusive equivalent (">9" -> 10, "<10" -> 9) so it can
            # be compared with the cohort operator on the same variable.
            # Example: with a cohort definition of <10, the level ">9"
            # becomes 10 and is filtered out by the cohort constraint.
            if level["operator"] == "<":
                value -= 1
            elif level["operator"] == ">":
                value += 1
    else:
        try:
            value = table_.c[k]
        except KeyError:
            raise HTTPException(
                status_code=400, detail=f"No feature named '{k}'"
            ) from None

    def operand(key):
        # Normalise one operand of the qualifier the same way for every operator.
        return simplify_value(v[key], operator_)

    def between_op():
        lower = operand("value_a")
        upper = operand("value_b")
        if level_mode:
            # Plain values: SQLAlchemy's between() would build a SQL clause
            # with no truth value instead of answering the question.
            return lower <= value <= upper
        return between(value, lower, upper)

    def in_op():
        candidates = [simplify_value(val, operator_) for val in v["values"]]
        if level_mode:
            # Plain values have no ``in_`` method; use ordinary membership.
            return value in candidates
        return value.in_(candidates)

    # Entries are callables so only the keys needed by the chosen operator
    # are read from ``v``.
    operations = {
        ">": lambda: value > operand("value"),
        "<": lambda: value < operand("value"),
        ">=": lambda: value >= operand("value"),
        "<=": lambda: value <= operand("value"),
        "=": lambda: value == operand("value"),
        "<>": lambda: value != operand("value"),
        "between": between_op,
        "in": in_op,
    }
    return operations[operator_]()

Expand All @@ -87,26 +77,26 @@ def filter_select(s, k, v, table_):
"""Add WHERE clause to selection."""
return s.where(
op_dict(
k, v, table_,
k, v, table_=table_,
)
)


def case_select(table_, k, v):
    """Build an aggregate that counts the rows matching one qualifier.

    Args:
        table_: Table whose column ``k`` is tested.
        k: Feature (column) name.
        v: Qualifier dict understood by ``op_dict``.

    Returns:
        ``COALESCE(SUM(CASE WHEN <condition> THEN 1 ELSE 0 END), 0)`` as a
        SQLAlchemy expression; the outer coalesce supplies 0 when the sum
        is NULL.
    """
    condition = op_dict(k, v, table_=table_)
    matched = case([(condition, 1)], else_=0)
    return func.coalesce(func.sum(matched), 0)


def case_select2(table1, table2, k, v, k2, v2):
    """Build an aggregate that counts the rows matching two qualifiers at once.

    Args:
        table1: Table whose column ``k`` is tested against ``v``.
        table2: Table whose column ``k2`` is tested against ``v2``.
        k: First feature (column) name.
        v: Qualifier dict for ``k``, as understood by ``op_dict``.
        k2: Second feature (column) name.
        v2: Qualifier dict for ``k2``, as understood by ``op_dict``.

    Returns:
        ``COALESCE(SUM(CASE WHEN <cond1> AND <cond2> THEN 1 ELSE 0 END), 0)``
        as a SQLAlchemy expression; the outer coalesce supplies 0 when the
        sum is NULL.
    """
    first = op_dict(k, v, table_=table1)
    second = op_dict(k2, v2, table_=table2)
    matched = case([(and_(first, second), 1)], else_=0)
    return func.coalesce(func.sum(matched), 0)

Expand Down Expand Up @@ -848,12 +838,22 @@ def select_feature_count_all_values(
return count


def get_feature_levels(feature, year=None, cohort_feat_dict=None):
    """Get the value-set levels of ``feature``, narrowed by year and cohort.

    Args:
        feature: Feature name to look up in ``get_value_sets()``.
        year: Optional year.  When ``feature`` is ``'year'`` and ``int(year)``
            is one of its levels, the level list is reduced to that year.
        cohort_feat_dict: Optional mapping of feature name to a qualifier
            dict (``{"operator": ..., ...}``).  When it has an entry for
            ``feature``, only the levels for which ``op_dict(level, qualifier)``
            is truthy are kept.

    Returns:
        List of levels; empty when the feature has no value set.
    """
    feat_levs = get_value_sets().get(feature, [])
    if year and feature == 'year' and int(year) in feat_levs:
        # only include the passed-in year in the year feature level list
        feat_levs = [int(year)]

    # Filter the levels by the cohort constraint on this same feature, if any.
    # NOTE(review): nesting was reconstructed from a whitespace-stripped diff;
    # this assumes both original loops filtered on the entry whose key equals
    # ``feature`` -- confirm against the repository.
    if cohort_feat_dict and feature in cohort_feat_dict:
        constraint = cohort_feat_dict[feature]
        return [lev for lev in feat_levs if op_dict(lev, constraint)]

    return feat_levs


Expand Down Expand Up @@ -1034,6 +1034,23 @@ def validate_feature_value_in_table_column_for_equal_operator(conn, table_name,
return


def get_level_operator_and_value(input_level):
    """Split a value-set level into its operator prefix and its value.

    A string level may start with any run of ``<`` and ``>`` characters;
    that run is the operator and the rest of the string is the value.  A
    string without such a prefix, or any non-string level, is treated as
    an equality on the level itself.

    Args:
        input_level: A level such as ``10``, ``"10"``, ``"<10"`` or ``">9"``.

    Returns:
        Dict with keys ``"operator"`` and ``"value"``.
    """
    if not isinstance(input_level, str):
        return {"operator": '=', "value": input_level}

    remainder = input_level.lstrip('<>')
    prefix_len = len(input_level) - len(remainder)
    if prefix_len == 0:
        # No leading comparison characters: the whole level is the value.
        return {"operator": '=', "value": input_level}

    return {"operator": input_level[:prefix_len], "value": remainder}


def get_operator_and_value(input_levels, feat_name, append_feature_variable=False):
"""
get operator and value from each input level which will be in the format of '>' or '<' followed by a number or
Expand All @@ -1042,23 +1059,11 @@ def get_operator_and_value(input_levels, feat_name, append_feature_variable=Fals
"""
fqs = []
for input_level in input_levels:
non_op_idx = 0
if isinstance(input_level, str):
for lev in input_level:
if lev in ['<', '>']:
non_op_idx += 1
else:
break
if non_op_idx == 0:
op = '='
op_val = input_level
else:
op = input_level[:non_op_idx]
op_val = input_level[non_op_idx:]
op_val_dict = get_level_operator_and_value(input_level)
if append_feature_variable:
fqs.append({feat_name: {"operator": op, "value": op_val}})
fqs.append({feat_name: op_val_dict})
else:
fqs.append({"operator": op, "value": op_val})
fqs.append(op_val_dict)
return fqs


Expand All @@ -1074,23 +1079,26 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
"for computing multivariate associations")

# get feature_constraint list from the first feature variable
feat_constraint_list = get_operator_and_value(get_feature_levels(feature_variables[0], year=year),
feat_constraint_list = get_operator_and_value(get_feature_levels(feature_variables[0], year=year,
cohort_feat_dict=cohort_features),
feature_variables[0], append_feature_variable=True)

index = 1
while index + 2 <= feat_len:
feature_as = [
{
"feature_name": feature_variables[index],
"feature_qualifiers": get_operator_and_value(get_feature_levels(feature_variables[index], year=year),
"feature_qualifiers": get_operator_and_value(get_feature_levels(feature_variables[index], year=year,
cohort_feat_dict=cohort_features),
feature_variables[index])
}
]
feature_bs = [
{
"feature_name": feature_variables[index + 1],
"feature_qualifiers": get_operator_and_value(get_feature_levels(feature_variables[index + 1],
year=year),
year=year,
cohort_feat_dict=cohort_features),
feature_variables[index + 1])
}
]
Expand All @@ -1116,7 +1124,8 @@ def compute_multivariate_table(conn, table_name, year, cohort_id, feature_variab
index += 2

if index < feat_len:
feature_qualifiers = get_operator_and_value(get_feature_levels(feature_variables[index], year=year),
feature_qualifiers = get_operator_and_value(get_feature_levels(feature_variables[index], year=year,
cohort_feat_dict=cohort_features),
feature_variables[index])
more_constraint_list = []
for feature_constraint in feat_constraint_list:
Expand Down
Loading