Skip to content

Commit

Permalink
Merge pull request #146 from prio-data/update_models_main_py
Browse files (browse the repository at this point in the history)
update templates and main files
  • Loading branch information
Polichinel authored Oct 31, 2024
2 parents (923cb88 + 315c3c9); commit ef17f48
Show file tree
Hide file tree
Showing 17 changed files with 118 additions and 57 deletions.
10 changes: 9 additions & 1 deletion ensembles/cruel_summer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.ensemble_path import EnsemblePath
from common_utils.global_cache import GlobalCache
model_name = EnsemblePath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
10 changes: 9 additions & 1 deletion ensembles/white_mustang/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.ensemble_path import EnsemblePath
from common_utils.global_cache import GlobalCache
model_name = EnsemblePath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
39 changes: 15 additions & 24 deletions meta_tools/templates/ensemble/template_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,17 @@

def generate(script_dir: Path) -> bool:
"""
Generates a Python script that sets up and executes model runs with Weights & Biases (WandB) integration.
This function creates a script that imports necessary modules, sets up project paths, and defines the
main execution logic for running either a single model run or a sweep of model configurations. The
generated script includes command-line argument parsing, validation, and runtime logging.
Generates a script that sets up the project paths, parses command-line arguments,
sets up logging, and executes a single model run.
Parameters:
script_dir (Path):
The directory where the generated Python script will be saved. This should be a valid writable
path that exists within the project structure.
The directory where the generated script will be saved.
This should be a valid writable path.
Returns:
bool:
True if the script was successfully written to the specified directory, False otherwise.
The generated script includes the following features:
- Imports required libraries and sets up the path to include the `common_utils` module.
- Initializes project paths using the `setup_project_paths` function.
- Parses command-line arguments with `parse_args`.
- Validates arguments to ensure correctness with `validate_arguments`.
- Logs into Weights & Biases using `wandb.login()`.
- Executes a model run based on the provided command-line flags, either initiating a sweep or a single run.
- Calculates and prints the runtime of the execution in minutes.
Note:
- Ensure that the `common_utils` module and all other imported modules are accessible from the
specified script directory.
- The generated script is designed to be executed as a standalone Python script.
True if the script was written and compiled successfully, False otherwise.
"""
code = """import wandb
import sys
Expand All @@ -49,8 +32,16 @@ def generate(script_dir: Path) -> bool:
from execute_model_runs import execute_single_run
warnings.filterwarnings("ignore")
logger = setup_logging('run.log')
try:
from common_utils.ensemble_path import EnsemblePath
from common_utils.global_cache import GlobalCache
model_name = EnsemblePath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model separated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")
if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions meta_tools/templates/model/template_config_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def generate(script_dir: Path, model_name: str, model_algorithm: str) -> bool:
\"""
meta_config = {{
"name": "{model_name}", # Eg. happy_kitten
"algorithm": "{model_algorithm}," # Eg. "LSTM", "CNN", "Transformer"
"name": "{model_name}",
"algorithm": "{model_algorithm}",
# Uncomment and modify the following lines as needed for additional metadata:
# "target(S)": ["ln_sb_best", "ln_ns_best", "ln_os_best", "ln_sb_best_binarized", "ln_ns_best_binarized", "ln_os_best_binarized"],
# "depvar": "ln_ged_sb_dep",
# "queryset": "escwa001_cflong",
# "level": "pgm",
# "creator": "Your name here"
Expand Down
2 changes: 1 addition & 1 deletion meta_tools/templates/model/template_evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def generate(script_dir: Path) -> bool:
import logging
from model_path import ModelPath
from utils_log_files import create_log_file, read_log_file
from utils_outputs import save_model_outputs, save_predictions
from utils_save_outputs import save_model_outputs, save_predictions
from utils_run import get_standardized_df
from utils_artifacts import get_latest_model_artifact
from utils_evaluation_metrics import generate_metric_dict
Expand Down
4 changes: 2 additions & 2 deletions meta_tools/templates/model/template_evaluate_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def evaluate_sweep(config, stepshift_model):
run_type = config["run_type"]
steps = config["steps"]
df_viewser = pd.read_pickle(path_raw / f"{{{{run_type}}}}_viewser_df.pkl")
df_viewser = pd.read_pickle(path_raw / f"{{run_type}}_viewser_df.pkl")
df = stepshift_model.predict(run_type, df_viewser)
df = get_standardized_df(df, config)
# Temporarily keep this because the metric to minimize is MSE
pred_cols = [f"step_pred_{{{{str(i)}}}}" for i in steps]
pred_cols = [f"step_pred_{{str(i)}}" for i in steps]
df["mse"] = df.apply(lambda row: mean_squared_error([row[config["depvar"]]] * 36,
[row[col] for col in pred_cols]), axis=1)
Expand Down
12 changes: 6 additions & 6 deletions meta_tools/templates/model/template_generate_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def generate(script_dir: Path) -> bool:
from model_path import ModelPath
from utils_log_files import create_log_file, read_log_file
from utils_run import get_standardized_df
from utils_outputs import save_predictions
from utils_save_outputs import save_predictions
from utils_artifacts import get_latest_model_artifact
logger = logging.getLogger(__name__)
Expand All @@ -38,28 +38,28 @@ def forecast_model_artifact(config, artifact_name):
# if an artifact name is provided through the CLI, use it.
# Otherwise, get the latest model artifact based on the run type
if artifact_name:
logger.info(f"Using (non-default) artifact: {{{{artifact_name}}}}")
logger.info(f"Using (non-default) artifact: {{artifact_name}}")
if not artifact_name.endswith(".pkl"):
artifact_name += ".pkl"
path_artifact = path_artifacts / artifact_name
else:
# use the latest model artifact based on the run type
logger.info(f"Using latest (default) run type ({{{{run_type}}}}) specific artifact")
logger.info(f"Using latest (default) run type ({{run_type}}) specific artifact")
path_artifact = get_latest_model_artifact(path_artifacts, run_type)
config["timestamp"] = path_artifact.stem[-15:]
df_viewser = pd.read_pickle(path_raw / f"{{{{run_type}}}}_viewser_df.pkl")
df_viewser = pd.read_pickle(path_raw / f"{{run_type}}_viewser_df.pkl")
try:
stepshift_model = pd.read_pickle(path_artifact)
except FileNotFoundError:
logger.exception(f"Model artifact not found at {{{{path_artifact}}}}")
logger.exception(f"Model artifact not found at {{path_artifact}}")
df_predictions = stepshift_model.predict(run_type, df_viewser)
df_predictions = get_standardized_df(df_predictions, config)
data_generation_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
date_fetch_timestamp = read_log_file(path_raw / f"{{{{run_type}}}}_data_fetch_log.txt").get("Data Fetch Timestamp", None)
date_fetch_timestamp = read_log_file(path_raw / f"{{run_type}}_data_fetch_log.txt").get("Data Fetch Timestamp", None)
save_predictions(df_predictions, path_generated, config)
create_log_file(path_generated, config, config["timestamp"], data_generation_timestamp, date_fetch_timestamp)
Expand Down
13 changes: 8 additions & 5 deletions meta_tools/templates/model/template_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,18 @@ def generate(script_dir: Path) -> bool:
try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
GlobalCache["current_model"] = ModelPath.get_model_name_from_path(Path(__file__))
except Exception:
pass
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")
if __name__ == "__main__":
wandb.login()
args = parse_args()
validate_arguments(args)
Expand All @@ -69,4 +72,4 @@ def generate(script_dir: Path) -> bool:
else:
execute_single_run(args)
"""
return utils_script_gen.save_script(script_dir, code)
return utils_script_gen.save_script(script_dir, code)
6 changes: 3 additions & 3 deletions meta_tools/templates/model/template_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def train_model_artifact(config):
path_generated = model_path.data_generated
path_artifacts = model_path.artifacts
run_type = config["run_type"]
df_viewser = pd.read_pickle(path_raw / f"{{{{run_type}}}}_viewser_df.pkl")
df_viewser = pd.read_pickle(path_raw / f"{{run_type}}_viewser_df.pkl")
stepshift_model = stepshift_training(config, run_type, df_viewser)
if not config["sweep"]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"{{{{run_type}}}}_model_{{{{timestamp}}}}.pkl"
model_filename = f"{{run_type}}_model_{{timestamp}}.pkl"
stepshift_model.save(path_artifacts / model_filename)
date_fetch_timestamp = read_log_file(path_raw / f"{{{{run_type}}}}_data_fetch_log.txt").get("Data Fetch Timestamp", None)
date_fetch_timestamp = read_log_file(path_raw / f"{{run_type}}_data_fetch_log.txt").get("Data Fetch Timestamp", None)
create_log_file(path_generated, config, timestamp, None, date_fetch_timestamp)
return stepshift_model
Expand Down
4 changes: 2 additions & 2 deletions meta_tools/templates/model/template_utils_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_model(config, partitioner_dict):
Get the model based on the algorithm specified in the config
\"""
if config["algorithm"] == "HurdleRegression":
if config["algorithm"] == "HurdleModel":
model = HurdleModel(config, partitioner_dict)
else:
config["model_reg"] = config["algorithm"]
Expand All @@ -49,7 +49,7 @@ def get_standardized_df(df, config):
if run_type in ["calibration", "testing"]:
cols = [depvar] + df.forecasts.prediction_columns
elif run_type == "forecasting":
cols = [f"step_pred_{{{{i}}}}" for i in steps]
cols = [f"step_pred_{{i}}" for i in steps]
df = df.replace([np.inf, -np.inf], 0)[cols]
df = df.mask(df < 0, 0)
return df
Expand Down
10 changes: 9 additions & 1 deletion models/blank_space/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@
from execute_model_runs import execute_sweep_run, execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
10 changes: 9 additions & 1 deletion models/electric_relaxation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_sweep_run, execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging('run.log')


Expand Down
9 changes: 6 additions & 3 deletions models/lavender_haze/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
GlobalCache["current_model"] = ModelPath.get_model_name_from_path(Path(__file__))
except Exception:
pass
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
10 changes: 9 additions & 1 deletion models/old_money/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_sweep_run, execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
10 changes: 9 additions & 1 deletion models/orange_pasta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_sweep_run, execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
10 changes: 9 additions & 1 deletion models/wildest_dream/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_sweep_run, execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down
10 changes: 9 additions & 1 deletion models/yellow_pikachu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from execute_model_runs import execute_sweep_run, execute_single_run

warnings.filterwarnings("ignore")

try:
from common_utils.model_path import ModelPath
from common_utils.global_cache import GlobalCache
model_name = ModelPath.get_model_name_from_path(PATH)
GlobalCache["current_model"] = model_name
except ImportError as e:
warnings.warn(f"ImportError: {e}. Some functionalities (model seperated log files) may not work properly.", ImportWarning)
except Exception as e:
warnings.warn(f"An unexpected error occurred: {e}.", RuntimeWarning)
logger = setup_logging("run.log")


Expand Down

0 comments on commit ef17f48

Please sign in to comment.