Skip to content

Commit

Permalink
Add not-fitted error and fix count_thres issues (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
qubixes authored Sep 14, 2023
1 parent 88dd6b6 commit d15e314
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion regexmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Regex model for fitting and generating structured strings."""

from regexmodel.model import RegexModel
from regexmodel.util import NotFittedError

__all__ = ["RegexModel"]
__all__ = ["RegexModel", "NotFittedError"]
17 changes: 12 additions & 5 deletions regexmodel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import polars as pl

from regexmodel.regexclass import fit_best_regex_class
from regexmodel.util import sum_log, LOG_LIKE_PER_CHAR
from regexmodel.util import sum_log, LOG_LIKE_PER_CHAR, NotFittedError
from regexmodel.datastructure import Edge, OrNode, RegexNode
# from regexmodel.model import fit_main_branch

Expand All @@ -22,6 +22,8 @@ def _simplify_edge(edge):
return edge
if len(node.edges) == 1:
return node.edges[0]
if node.count == 0:
edge.destination = None
return edge


Expand Down Expand Up @@ -57,11 +59,11 @@ def fit_main_branch(series: pl.Series,

# Add an END edge
n_end_links = int((series == "").sum())
if n_end_links > count_thres:
if n_end_links >= count_thres:
return_node.add_edge(Edge(None, n_end_links))
cur_series = series.set(series == "", None) # type: ignore

while cur_series.drop_nulls().len() > count_thres:
while cur_series.drop_nulls().len() >= count_thres:
result = fit_best_regex_class(cur_series, count_thres, force_merge=force_merge)

# If the regex fails the threshold, stop the search.
Expand All @@ -88,7 +90,7 @@ def fit_main_branch(series: pl.Series,
cur_or_node = OrNode([Edge(None, new_edge.count)], main_edge)

alt_series = result["alt_series"]
if alt_series.drop_nulls().len() > count_thres:
if alt_series.drop_nulls().len() >= count_thres:
opt_series = alt_series.str.extract(r"(^[\S\s]*?)" + main_edge.regex + r"$")
alt_edge = fit_main_branch(opt_series, count_thres, force_merge=force_merge)
if alt_edge.count > 0:
Expand Down Expand Up @@ -168,7 +170,12 @@ def fit(cls, values: Union[Iterable, Sequence], count_thres: int = 3,
force_merge = False
else:
force_merge = True
return cls(fit_main_branch(values, count_thres=count_thres, force_merge=force_merge))
regex_edge = fit_main_branch(values, count_thres=count_thres, force_merge=force_merge)
if regex_edge.count == 0:
raise NotFittedError(f"Could not fit regex on values, with count_thres={count_thres}"
f" and method='{method}'. Try lowering count_thres or using "
"method='fast'.")
return cls(regex_edge)

@classmethod
def from_regex(cls, regex_str: str):
Expand Down
2 changes: 1 addition & 1 deletion regexmodel/regexclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def get_class_stat(series: pl.Series, count_thres: int) -> list:
score_list: list[tuple[BaseRegex, float, pl.Series]] = []
for rclass in [UpperRegex, LowerRegex, DigitRegex]:
cur_class_stat = score_single(rclass.base_regex, series, count_thres, rclass.n_possible)
if cur_class_stat[0] > 0 and cur_class_stat[1].drop_nulls().len() > count_thres:
if cur_class_stat[0] > 0 and cur_class_stat[1].drop_nulls().len() >= count_thres:
score_list.append((rclass(), *cur_class_stat))
score_list.extend(LiteralRegex.get_candidates(series, count_thres))
return sorted(score_list, key=lambda res: -res[1])
Expand Down
4 changes: 4 additions & 0 deletions regexmodel/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ def sum_log(log_likes):
rel_probs = np.exp(log_likes-max_log)
sum_prob = np.sum(rel_probs)
return np.log(sum_prob) + max_log


class NotFittedError(ValueError):
"""Signal that the regex could not be fitted."""

0 comments on commit d15e314

Please sign in to comment.