Skip to content

Commit

Permalink
Merging in recent changes from main
Browse files Browse the repository at this point in the history
Merge branch 'main' into adr_016_drift_logging
  • Loading branch information
jimdale committed Oct 29, 2024
2 parents 7509f84 + 9724c9a commit 7feace1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
34 changes: 16 additions & 18 deletions common_utils/model_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ModelPath:
visualization (Path): The directory for visualization scripts.
_sys_paths (list): A list of system paths.
common_querysets (Path): The directory for common querysets.
_queryset_path (Path): The path to the queryset script.
queryset_path (Path): The path to the queryset script.
_queryset (module): The imported queryset module.
scripts (list): A list of script paths.
_ignore_attributes (list): A list of paths to ignore.
Expand Down Expand Up @@ -96,7 +96,7 @@ class ModelPath:
"visualization",
"_sys_paths",
"common_querysets",
"_queryset_path",
"queryset_path",
"scripts",
"meta_tools",
)
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(
"_validate",
"models",
"_sys_paths",
"_queryset_path",
"queryset_path",
"_queryset",
"_ignore_attributes",
"target",
Expand Down Expand Up @@ -346,7 +346,6 @@ def _handle_global_cache(self) -> None:
"""
try:
from global_cache import GlobalCache

cached_instance = GlobalCache[self._instance_hash]
if cached_instance and not self._force_cache_overwrite:
logger.info(
Expand All @@ -357,7 +356,7 @@ def _handle_global_cache(self) -> None:
logger.error(
f"Error adding model {self.model_name} to cache: {e}. Initializing new ModelPath instance."
)

def _write_to_global_cache(self) -> None:
"""
Writes the current model instance to the global cache if it doesn't exist.
Expand Down Expand Up @@ -407,7 +406,7 @@ def _initialize_directories(self) -> None:
self._sys_paths = None
if self.common_querysets not in sys.path:
sys.path.insert(0, str(self.common_querysets))
self._queryset_path = self.common_querysets / f"queryset_{self.model_name}.py"
self.queryset_path = self.common_querysets / f"queryset_{self.model_name}.py"
self._queryset = None

def _initialize_scripts(self) -> None:
Expand Down Expand Up @@ -437,7 +436,8 @@ def _initialize_scripts(self) -> None:
self._build_absolute_directory(
Path("src/offline_evaluation/evaluate_model.py")
),
self._build_absolute_directory(Path("src/training/train_ensemble.py")),
self._build_absolute_directory(Path(f"src/training/train_{self.target}.py")),
self.common_querysets / f"queryset_{self.model_name}.py"
]

def _is_path(self, path_input: Union[str, Path]) -> bool:
Expand Down Expand Up @@ -476,18 +476,18 @@ def get_queryset(self) -> Optional[Dict[str, str]]:
error = f"Common queryset directory {self.common_querysets} does not exist. Please create it first using `make_new_scripts.py` or set validate to `False`."
logger.error(error)
raise FileNotFoundError(error)
elif self._validate and self._check_if_dir_exists(self._queryset_path):
elif self._validate and self._check_if_dir_exists(self.queryset_path):
try:
self._queryset = importlib.import_module(self._queryset_path.stem)
self._queryset = importlib.import_module(self.queryset_path.stem)
except Exception as e:
logger.error(f"Error importing queryset: {e}")
self._queryset = None
else:
logger.info(f"Queryset {self._queryset_path} imported successfully.")
logger.info(f"Queryset {self.queryset_path} imported successfully.")
return self._queryset.generate() if self._queryset else None
else:
logger.warning(
f"Queryset {self._queryset_path} does not exist. Continuing..."
f"Queryset {self.queryset_path} does not exist. Continuing..."
)
return None

Expand Down Expand Up @@ -677,7 +677,7 @@ def get_directories(self) -> Dict[str, Optional[str]]:
# "_validate",
# "models",
# "_sys_paths",
# "_queryset_path",
# "queryset_path",
# "_queryset",
# "_ignore_attributes",
# "target",
Expand All @@ -699,7 +699,7 @@ def get_directories(self) -> Dict[str, Optional[str]]:
"templates",
"_sys_paths",
"_queryset",
"_queryset_path",
"queryset_path",
"_ignore_attributes",
"target",
"_force_cache_overwrite",
Expand Down Expand Up @@ -743,9 +743,7 @@ def get_scripts(self) -> Dict[str, Optional[str]]:

# if __name__ == "__main__":
# model_path = ModelPath("taco_cat", validate=False)
# model_path.view_directories()
# model_path.view_scripts()
# model_path.get_directories()
# model_path.get_scripts()
# print(model_path.get_scripts())
# print(model_path.get_directories())
# print(model_path.get_queryset())
# print(ModelPath.get_common_configs())
# print(model_path.queryset_path)
2 changes: 1 addition & 1 deletion common_utils/utils_logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging


def setup_logging(log_file: str, log_level=logging.INFO):
def setup_logging(log_file: str, log_level=logging.INFO) -> logging.Logger:
"""
Sets up logging to both a specified file and the terminal (console).
Expand Down
4 changes: 3 additions & 1 deletion meta_tools/tests/test_model_scaffold_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
else:
raise ValueError("The 'views_pipeline' directory was not found in the provided path.")

print(sys.path)

import os
import pytest
from unittest.mock import patch, MagicMock
import tempfile
import shutil
from meta_tools.model_scaffold_builder import ModelScaffoldBuilder
from model_scaffold_builder import ModelScaffoldBuilder


@pytest.fixture
Expand Down

0 comments on commit 7feace1

Please sign in to comment.