Skip to content

Commit

Permalink
move custom_logger to its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
khintz committed Jan 22, 2025
1 parent f7f90c4 commit 3187dc3
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 87 deletions.
61 changes: 61 additions & 0 deletions neural_lam/custom_loggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Standard library
import sys

# Third-party
# third-party
import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
from loguru import logger


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
"""
Custom MLFlow logger that adds functionality not present in the default
"""

def __init__(self, experiment_name, tracking_uri, run_name):
super().__init__(
experiment_name=experiment_name, tracking_uri=tracking_uri
)

mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
mlflow.set_tag("mlflow.runName", run_name)
mlflow.log_param("run_id", self.run_id)

@property
def save_dir(self):
"""
Returns the directory where the MLFlow artifacts are saved
"""
return "mlruns"

def log_image(self, key, images, step=None):
"""
Log a matplotlib figure as an image to MLFlow
key: str
Key to log the image under
images: list
List of matplotlib figures to log
step: Union[int, None]
Step to log the image under. If None, logs under the key directly
"""
# Third-party
import botocore
from PIL import Image

if step is not None:
key = f"{key}_{step}"

# Need to save the image to a temporary file, then log that file
# mlflow.log_image, should do this automatically, but is buggy
temporary_image = f"{key}.png"
images[0].savefig(temporary_image)

img = Image.open(temporary_image)
try:
mlflow.log_image(img, f"{key}.png")
except botocore.exceptions.NoCredentialsError:
logger.error("Error logging image\nSet AWS credentials")
sys.exit(1)
89 changes: 2 additions & 87 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
# Standard library
import json
import os
import random
import sys
import time
from argparse import ArgumentParser

# Third-party
import mlflow

# for logging the model:
import mlflow.pytorch
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities import seed
Expand All @@ -29,85 +24,6 @@
}


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
"""
Custom MLFlow logger that adds functionality not present in the default
"""

def __init__(self, experiment_name, tracking_uri, run_name):
super().__init__(
experiment_name=experiment_name, tracking_uri=tracking_uri
)

mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
mlflow.set_tag("mlflow.runName", run_name)
mlflow.log_param("run_id", self.run_id)

@property
def save_dir(self):
"""
Returns the directory where the MLFlow artifacts are saved
"""
return "mlruns"

def log_image(self, key, images, step=None):
"""
Log a matplotlib figure as an image to MLFlow
key: str
Key to log the image under
images: list
List of matplotlib figures to log
step: Union[int, None]
Step to log the image under. If None, logs under the key directly
"""
# Third-party
import botocore
from PIL import Image

if step is not None:
key = f"{key}_{step}"

# Need to save the image to a temporary file, then log that file
# mlflow.log_image, should do this automatically, but is buggy
temporary_image = f"{key}.png"
images[0].savefig(temporary_image)

img = Image.open(temporary_image)
try:
mlflow.log_image(img, f"{key}.png")
except botocore.exceptions.NoCredentialsError:
logger.error("Error logging image\nSet AWS credentials")
sys.exit(1)


@pl.utilities.rank_zero.rank_zero_only
def _setup_training_logger(config, datastore, args, run_name):

if args.logger == "wandb":
logger = pl.loggers.WandbLogger(
project=args.logger_project,
name=run_name,
config=dict(training=vars(args), datastore=datastore._config),
)
elif args.logger == "mlflow":
url = os.getenv("MLFLOW_TRACKING_URI")
if url is None:
raise ValueError(
"MLFlow logger requires setting MLFLOW_TRACKING_URI in env."
)
logger = CustomMLFlowLogger(
experiment_name=args.logger_project,
tracking_uri=url,
run_name=run_name,
)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)

return logger


@logger.catch
def main(input_args=None):
"""Main function for training and evaluating models."""
Expand Down Expand Up @@ -354,9 +270,8 @@ def main(input_args=None):
f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
)

# Only initialise logger on rank 0
training_logger = _setup_training_logger(
config=config, datastore=datastore, args=args, run_name=run_name
training_logger = utils.setup_training_logger(
datastore=datastore, args=args, run_name=run_name
)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
Expand Down
49 changes: 49 additions & 0 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
import warnings

# Third-party
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import MLFlowLogger, WandbLogger
from torch import nn
from tueplots import bundles, figsizes

# Local
from .custom_loggers import CustomMLFlowLogger


class BufferList(nn.Module):
"""
Expand Down Expand Up @@ -251,3 +255,48 @@ def init_training_logger_metrics(training_logger, val_steps):
"Only WandbLogger & MLFlowLogger is supported for tracking metrics.\
Experiment results will only go to stdout."
)


@pl.utilities.rank_zero.rank_zero_only
def setup_training_logger(datastore, args, run_name):
"""
Parameters
----------
datastore : Datastore
Datastore object.
args : argparse.Namespace
Arguments from command line.
run_name : str
Name of the run.
Returns
-------
logger : pytorch_lightning.loggers.base
Logger object.
"""

if args.logger == "wandb":
logger = pl.loggers.WandbLogger(
project=args.logger_project,
name=run_name,
config=dict(training=vars(args), datastore=datastore._config),
)
elif args.logger == "mlflow":
url = os.getenv("MLFLOW_TRACKING_URI")
if url is None:
raise ValueError(
"MLFlow logger requires setting MLFLOW_TRACKING_URI in env."
)
logger = CustomMLFlowLogger(
experiment_name=args.logger_project,
tracking_uri=url,
run_name=run_name,
)
logger.log_hyperparams(
dict(training=vars(args), datastore=datastore._config)
)

return logger

0 comments on commit 3187dc3

Please sign in to comment.