diff --git a/regexmodel/__init__.py b/regexmodel/__init__.py index d3ebe1c..f6fc8fc 100644 --- a/regexmodel/__init__.py +++ b/regexmodel/__init__.py @@ -1,5 +1,6 @@ """Regex model for fitting and generating structured strings.""" from regexmodel.model import RegexModel +from regexmodel.util import NotFittedError -__all__ = ["RegexModel"] +__all__ = ["RegexModel", "NotFittedError"] diff --git a/regexmodel/model.py b/regexmodel/model.py index 5ac0378..0b91f81 100644 --- a/regexmodel/model.py +++ b/regexmodel/model.py @@ -7,7 +7,7 @@ import polars as pl from regexmodel.regexclass import fit_best_regex_class -from regexmodel.util import sum_log, LOG_LIKE_PER_CHAR +from regexmodel.util import sum_log, LOG_LIKE_PER_CHAR, NotFittedError from regexmodel.datastructure import Edge, OrNode, RegexNode # from regexmodel.model import fit_main_branch @@ -22,6 +22,8 @@ def _simplify_edge(edge): return edge if len(node.edges) == 1: return node.edges[0] + if node.count == 0: + edge.destination = None return edge @@ -57,11 +59,11 @@ def fit_main_branch(series: pl.Series, # Add an END edge n_end_links = int((series == "").sum()) - if n_end_links > count_thres: + if n_end_links >= count_thres: return_node.add_edge(Edge(None, n_end_links)) cur_series = series.set(series == "", None) # type: ignore - while cur_series.drop_nulls().len() > count_thres: + while cur_series.drop_nulls().len() >= count_thres: result = fit_best_regex_class(cur_series, count_thres, force_merge=force_merge) # If the regex fails the threshold, stop the search. @@ -88,7 +90,7 @@ def fit_main_branch(series: pl.Series, cur_or_node = OrNode([Edge(None, new_edge.count)], main_edge) alt_series = result["alt_series"] - if alt_series.drop_nulls().len() > count_thres: + if alt_series.drop_nulls().len() >= count_thres: opt_series = alt_series.str.extract(r"(^[\S\s]*?)" + main_edge.regex + r"$") alt_edge = fit_main_branch(opt_series, count_thres, force_merge=force_merge) if alt_edge.count > 0: @@ -168,7 +170,12 @@ def fit(cls, values: Union[Iterable, Sequence], count_thres: int = 3, force_merge = False else: force_merge = True - return cls(fit_main_branch(values, count_thres=count_thres, force_merge=force_merge)) + regex_edge = fit_main_branch(values, count_thres=count_thres, force_merge=force_merge) + if regex_edge.count == 0: + raise NotFittedError(f"Could not fit regex on values, with count_thres={count_thres}" + f" and method='{method}'. Try lowering count_thres or using " + "method='fast'.") + return cls(regex_edge) @classmethod def from_regex(cls, regex_str: str): diff --git a/regexmodel/regexclass.py b/regexmodel/regexclass.py index f7def30..cfea010 100644 --- a/regexmodel/regexclass.py +++ b/regexmodel/regexclass.py @@ -287,7 +287,7 @@ def get_class_stat(series: pl.Series, count_thres: int) -> list: score_list: list[tuple[BaseRegex, float, pl.Series]] = [] for rclass in [UpperRegex, LowerRegex, DigitRegex]: cur_class_stat = score_single(rclass.base_regex, series, count_thres, rclass.n_possible) - if cur_class_stat[0] > 0 and cur_class_stat[1].drop_nulls().len() > count_thres: + if cur_class_stat[0] > 0 and cur_class_stat[1].drop_nulls().len() >= count_thres: score_list.append((rclass(), *cur_class_stat)) score_list.extend(LiteralRegex.get_candidates(series, count_thres)) return sorted(score_list, key=lambda res: -res[1]) diff --git a/regexmodel/util.py b/regexmodel/util.py index f70afa6..8b09384 100644 --- a/regexmodel/util.py +++ b/regexmodel/util.py @@ -13,3 +13,7 @@ def sum_log(log_likes): rel_probs = np.exp(log_likes-max_log) sum_prob = np.sum(rel_probs) return np.log(sum_prob) + max_log + + +class NotFittedError(ValueError): + """Signal that the regex could not be fitted."""