Skip to content

Commit

Permalink
move logger from config to command line arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
khintz committed Jan 21, 2025
1 parent c80d36f commit 538c26d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
3 changes: 0 additions & 3 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ class TrainingConfig:
ManualStateFeatureWeighting, UniformFeatureWeighting
] = dataclasses.field(default_factory=UniformFeatureWeighting)

logger: str = "wandb"
logger_url: str = ""


@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
Expand Down
26 changes: 19 additions & 7 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,20 @@ def log_image(self, key, images, step=None):


def _setup_training_logger(config, datastore, args, run_name):
if config.training.logger == "wandb":
if args.logger == "wandb":
logger = pl.loggers.WandbLogger(
project=args.wandb_project,
project=args.logger_project,
name=run_name,
config=dict(training=vars(args), datastore=datastore._config),
)
elif config.training.logger == "mlflow":
url = config.training.logger_url
elif args.logger == "mlflow":
url = args.logger_url
if url is None:
raise ValueError(
"MLFlow logger requires a URL to the MLFlow server"
)
logger = CustomMLFlowLogger(
experiment_name=args.wandb_project,
experiment_name=args.logger_project,
tracking_uri=url,
)
logger.log_hyperparams(
Expand Down Expand Up @@ -251,10 +251,22 @@ def main(input_args=None):

# Logger Settings
parser.add_argument(
"--wandb_project",
"--logger",
type=str,
default="wandb",
help="Logger to use for training (wandb/mlflow) (default: wandb)",
)
parser.add_argument(
"--logger-url",
type=str,
default=None,
help="URL to the logger server (default: None)",
)
parser.add_argument(
"--logger_project",
type=str,
default="neural_lam",
help="Wandb project name (default: neural_lam)",
help="Logger project name, for eg. Wandb (default: neural_lam)",
)
parser.add_argument(
"--val_steps_to_log",
Expand Down

0 comments on commit 538c26d

Please sign in to comment.