diff --git a/src/ms3/utils/functions.py b/src/ms3/utils/functions.py index 39eb8bf6..675502aa 100644 --- a/src/ms3/utils/functions.py +++ b/src/ms3/utils/functions.py @@ -303,6 +303,39 @@ def check_labels( return pd.concat([df.loc[select_wrong, cols], res], axis=1) +def check_phrase_annotations(df: pd.DataFrame, column: str, logger=None) -> bool: + """""" + if logger is None: + logger = module_logger + elif isinstance(logger, str): + logger = get_logger(logger) + p_col = df[column] + opening = p_col.fillna("").str.count("{") + closing = p_col.fillna("").str.count("}") + if "mn_playthrough" in df.columns: + position_col = "mn_playthrough" + else: + logger.info( + "Column 'mn_playthrough' is missing, so my assessment of the phrase annotations might be wrong." + ) + position_col = "mn" + columns = [position_col, column] + if opening.sum() != closing.sum(): + o = df.loc[(opening > 0), columns] + c = df.loc[(closing > 0), columns] + compare = pd.concat( + [o.reset_index(drop=True), c.reset_index(drop=True)], axis=1 + ) + if "mn" in compare: + compare = compare.astype({"mn": "Int64"}) + logger.warning( + f"Phrase beginning and endings don't match:\n{compare.to_string(index=False)}", + extra={"message_id": (16,)}, + ) + return False + return True + + def color2rgba(c): """Pass a RGB or RGBA tuple, HTML color or name to convert it to RGBA""" if isinstance(c, tuple): @@ -1752,6 +1785,54 @@ def _fifths2str(fifths: int, steps: Collection[str], inverted: bool = False) -> return acc + steps[fifths % 7] +def get_name_of_highest_version_tag( + repo: git.Repo, +) -> Optional[str]: + descending_tags = repo.git.tag(l=True, sort="-v:refname") + latest_version = descending_tags.split("\n")[0] + if latest_version: + return latest_version + + +@cache +def get_git_commit( + repo_path: str, git_revision: Optional[str], logger=None +) -> Optional[git.Commit]: + """Returns the git commit object for the given revision. + + Args: + repo_path: + git_revision: + Any specifier that git understands (branch, tag, commit hash, "HEAD", etc.). In addition, + "LATEST_VERSION" can be passed to get the tag with the highest version number. + logger: + + Returns: + git.Commit object that corresponds to the given revision specifier. + """ + if logger is None: + logger = module_logger + elif isinstance(logger, str): + logger = get_logger(logger) + try: + repo = git.Repo(repo_path, search_parent_directories=True) + except Exception as e: + logger.error(f"{repo_path} is not an existing git repository: {e}") + return + if git_revision == "LATEST_VERSION": + git_revision = get_name_of_highest_version_tag(repo) + if git_revision is None: + logger.error( + "Could not find the latest version tag, falling back to current HEAD." + ) + try: + return repo.commit(git_revision) + except BadName: + logger.error( + f"{git_revision} does not resolve to a commit for repo {os.path.basename(repo_path)}." + ) + + def get_git_repo( directory: str | Path, search_parent_directories: bool = True, @@ -6510,54 +6591,6 @@ def string2identifier(s: str, remove_leading_underscore: bool = True) -> str: return s -def get_name_of_highest_version_tag( - repo: git.Repo, -) -> Optional[str]: - descending_tags = repo.git.tag(l=True, sort="-v:refname") - latest_version = descending_tags.split("\n")[0] - if latest_version: - return latest_version - - -@cache -def get_git_commit( - repo_path: str, git_revision: Optional[str], logger=None -) -> Optional[git.Commit]: - """Returns the git commit object for the given revision. - - Args: - repo_path: - git_revision: - Any specifier that git understands (branch, tag, commit hash, "HEAD", etc.). In addition, - "LATEST_VERSION" can be passed to get the tag with the highest version number. - logger: - - Returns: - git.Commit object that corresponds to the given revision specifier. - """ - if logger is None: - logger = module_logger - elif isinstance(logger, str): - logger = get_logger(logger) - try: - repo = git.Repo(repo_path, search_parent_directories=True) - except Exception as e: - logger.error(f"{repo_path} is not an existing git repository: {e}") - return - if git_revision == "LATEST_VERSION": - git_revision = get_name_of_highest_version_tag(repo) - if git_revision is None: - logger.error( - "Could not find the latest version tag, falling back to current HEAD." - ) - try: - return repo.commit(git_revision) - except BadName: - logger.error( - f"{git_revision} does not resolve to a commit for repo {os.path.basename(repo_path)}." - ) - - @cache def resolve_git_revision( repo_path: str, git_revision: Optional[str], logger=None @@ -6634,39 +6667,6 @@ def parse_tsv_file_at_git_revision( return new_file, parsed -def check_phrase_annotations(df: pd.DataFrame, column: str, logger=None) -> bool: - """""" - if logger is None: - logger = module_logger - elif isinstance(logger, str): - logger = get_logger(logger) - p_col = df[column] - opening = p_col.fillna("").str.count("{") - closing = p_col.fillna("").str.count("}") - if "mn_playthrough" in df.columns: - position_col = "mn_playthrough" - else: - logger.info( - "Column 'mn_playthrough' is missing, so my assessment of the phrase annotations might be wrong." - ) - position_col = "mn" - columns = [position_col, column] - if opening.sum() != closing.sum(): - o = df.loc[(opening > 0), columns] - c = df.loc[(closing > 0), columns] - compare = pd.concat( - [o.reset_index(drop=True), c.reset_index(drop=True)], axis=1 - ) - if "mn" in compare: - compare = compare.astype({"mn": "Int64"}) - logger.warning( - f"Phrase beginning and endings don't match:\n{compare.to_string(index=False)}", - extra={"message_id": (16,)}, - ) - return False - return True - - def write_messages_to_file_or_remove( warnings_file: str, warnings: List[str], header: str, logger=None ) -> bool: