sweep.py
import argparse
import warnings
from pathlib import Path

import wandb

from util.utils import *

warnings.filterwarnings("ignore")


def train():
    """Run a single W&B sweep trial.

    Initializes a wandb run, builds a run name from the sweep parameters,
    merges the common and model configurations into the run config, retrains
    the model on the transformed datasets, evaluates it, and finishes the run.

    Relies on the following globals set in the __main__ block:
    - common_config: common configuration parameters for the wandb run.
    - wandb_config: project and entity for the wandb run.
    - model_config: model configuration parameters.
    - Datasets_transformed: transformed datasets used for retraining.
    - para_transformed: transformed parameters used for evaluation.

    Returns nothing; results are logged and managed by wandb.
    """
    run = wandb.init()
    # Every sweep parameter except 'transform' belongs to the model.
    model_paras = [para for para in run.config.keys() if para != 'transform']
    # Build a descriptive run name from the sweep parameters.
    run_name = f'transform_{run.config["transform"]}'
    for para in model_paras:
        run_name += f'_{para}_{run.config[para]}'
    wandb.run.name = run_name
    # Merge the static configs into the sweep-supplied run config.
    wandb.config.update(common_config, allow_val_change=True)
    wandb.config.update(model_config, allow_val_change=True)
    retrain_transformed_sweep(Datasets_transformed, model_paras)
    evaluate('calib', para_transformed, by_group=True, plot_map=True)
    run.finish()
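

# Illustrative example (values are hypothetical, not taken from any real
# config): a sweep trial with run.config == {'transform': 'log',
# 'n_estimators': 100} gets the run name 'transform_log_n_estimators_100'.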


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Method for sweeping on W&B')
    parser.add_argument('-l', metavar='level', type=str, required=True, choices=['cm', 'pgm'],
                        help='Level of analysis: cm or pgm')
    parser.add_argument('-c', metavar='config', type=str, required=True,
                        help='Path to the configuration directory')
    parser.add_argument('-m', metavar='modelname',
                        help='Name of the model to implement')
    args = parser.parse_args()

    level = args.l
    config_path = Path(args.c)
    model_name = args.m

    transforms = ['raw', 'log', 'normalize', 'standardize']
    qslist, Datasets = i_fetch_data(level)
    Datasets_transformed, para_transformed = transform_data(Datasets, transforms, level=level, by_group=True)

    common_config_path, wandb_config_path, model_config_path, sweep_config_path = get_config_path(config_path)
    common_config = get_config_from_path(common_config_path, 'common')
    wandb_config = get_config_from_path(wandb_config_path, 'wandb')
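
    # Assumed layout (the real mapping is defined by get_config_path): the
    # config directory holds a common config, a wandb config, and matching
    # model/ and sweep/ subdirectories, where each sweep file has a model
    # config of the same name.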
    for sweep_file in sweep_config_path.iterdir():
        if sweep_file.is_file():
            # Skip if a specific model name is provided and it doesn't match the file.
            model_name_from_file = sweep_file.stem
            if model_name and model_name != model_name_from_file:
                continue
            # Each sweep config must have a model config of the same name.
            model_file = model_config_path / sweep_file.name
            if not model_file.is_file():
                raise FileNotFoundError(f'The corresponding model configuration file {model_file} does not exist.')
            sweep_config = get_config_from_path(sweep_file, 'sweep')
            model_config = get_config_from_path(model_file, 'model')
            sweep_id = wandb.sweep(sweep_config, project=wandb_config['project'], entity=wandb_config['entity'])
            wandb.agent(sweep_id, function=train)
            print(f'Finished sweeping over model {sweep_file.stem}')
            print('**************************************************************')
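
# Example invocation and sweep file (all names and values below are
# illustrative assumptions, not taken from the repository):
#
#   python sweep.py -l cm -c ./configs -m lightgbm
#
# where ./configs/sweep/lightgbm.yaml might follow the standard W&B sweep
# schema, e.g.:
#
#   method: grid
#   parameters:
#     transform:
#       values: [raw, log, normalize, standardize]
#     n_estimators:
#       values: [100, 200]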