Skip to content

Commit

Permalink
Add carbon emissions for training (stfc#265)
Browse files (browse the repository at this point in the history)
* Add carbon tracking to training

* Add missing tmp_path in singlepoint tests

* Move training print to logging
  • Loading branch information
ElliottKasoar authored Aug 14, 2024
1 parent 1be2f3f commit e0de4a6
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 11 deletions.
22 changes: 21 additions & 1 deletion janus_core/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typer import Option, Typer
import yaml

from janus_core.cli.types import LogPath, Summary
from janus_core.cli.utils import carbon_summary, end_summary, start_summary
from janus_core.helpers.train import train as run_train

app = Typer()
Expand All @@ -19,6 +21,8 @@ def train(
fine_tune: Annotated[
bool, Option(help="Whether to fine-tune a foundational model.")
] = False,
log: LogPath = "train.log",
summary: Summary = "train_summary.yml",
):
"""
Run training for MLIP by passing a configuration file to the MLIP's CLI.
Expand All @@ -29,6 +33,11 @@ def train(
Configuration file to pass to MLIP CLI.
fine_tune : bool
Whether to fine-tune a foundational model. Default is False.
log : Optional[Path]
Path to write logs to. Default is "train.log".
summary : Path
Path to save summary of inputs and start/end time. Default is
train_summary.yml.
"""
with open(mlip_config, encoding="utf8") as config_file:
config = yaml.safe_load(config_file)
Expand All @@ -52,4 +61,15 @@ def train(
elif "foundation_model" in config:
raise ValueError("Please include the `--fine-tune` option for fine-tuning")

run_train(mlip_config)
inputs = {"mlip_config": str(mlip_config), "fine_tune": fine_tune}

# Save summary information before training begins
start_summary(command="train", summary=summary, inputs=inputs)

# Run training
run_train(mlip_config, log_kwargs={"filename": log, "filemode": "w"})

carbon_summary(summary=summary, log=log)

# Save time after training has finished
end_summary(summary)
31 changes: 27 additions & 4 deletions janus_core/helpers/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Train MLIP."""

from pathlib import Path
from typing import Optional
from typing import Any, Optional

try:
from mace.cli.run_train import run as run_train
Expand All @@ -11,6 +11,8 @@
import yaml

from janus_core.helpers.janus_types import PathLike
from janus_core.helpers.log import config_logger, config_tracker
from janus_core.helpers.utils import none_to_dict


def check_files_exist(config: dict, req_file_keys: list[PathLike]) -> None:
Expand All @@ -37,7 +39,10 @@ def check_files_exist(config: dict, req_file_keys: list[PathLike]) -> None:


def train(
mlip_config: PathLike, req_file_keys: Optional[list[PathLike]] = None
mlip_config: PathLike,
req_file_keys: Optional[list[PathLike]] = None,
log_kwargs: Optional[dict[str, Any]] = None,
tracker_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Run training for MLIP by passing a configuration file to the MLIP's CLI.
Expand All @@ -52,7 +57,13 @@ def train(
req_file_keys : Optional[list[PathLike]]
List of files that must exist if defined in the configuration file.
Default is ["train_file", "test_file", "valid_file", "statistics_file"].
log_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to `config_logger`. Default is {}.
tracker_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to `config_tracker`. Default is {}.
"""
(log_kwargs, tracker_kwargs) = none_to_dict((log_kwargs, tracker_kwargs))

if req_file_keys is None:
req_file_keys = ["train_file", "test_file", "valid_file", "statistics_file"]

Expand All @@ -61,9 +72,21 @@ def train(
options = yaml.safe_load(file)
check_files_exist(options, req_file_keys)

if "foundation_model" in options:
print(f"Fine tuning model: {options['foundation_model']}")
# Configure logging
log_kwargs.setdefault("name", __name__)
logger = config_logger(**log_kwargs)
tracker = config_tracker(logger, **tracker_kwargs)

if logger and "foundation_model" in options:
logger.info("Fine tuning model: %s", options["foundation_model"])

# Path must be passed as a string
mlip_args = mace_parser().parse_args(["--config", str(mlip_config)])
if logger:
logger.info("Starting training")
tracker.start_task("Training")
run_train(mlip_args)
if logger:
logger.info("Training complete")
tracker.stop_task()
tracker.stop()
12 changes: 12 additions & 0 deletions tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ def test_invalid_config():
def test_write_kwargs(tmp_path):
"""Test setting invalidate_calc and write_results via write_kwargs."""
results_path = tmp_path / "NaCl-results.extxyz"
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
Expand All @@ -293,6 +295,10 @@ def test_write_kwargs(tmp_path):
"{'invalidate_calc': False}",
"--out",
results_path,
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 0
Expand All @@ -306,6 +312,8 @@ def test_write_kwargs(tmp_path):
def test_write_cif(tmp_path):
"""Test writing out a cif file."""
results_path = tmp_path / "NaCl-results.cif"
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
Expand All @@ -317,6 +325,10 @@ def test_write_cif(tmp_path):
"{'invalidate_calc': False, 'write_results': True}",
"--out",
results_path,
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 0
Expand Down
88 changes: 82 additions & 6 deletions tests/test_train_cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Test train commandline interface."""

import logging
from pathlib import Path
import shutil

from typer.testing import CliRunner
import yaml

from janus_core.cli.janus import app
from tests.utils import strip_ansi_codes
from tests.utils import assert_log_contains, strip_ansi_codes

DATA_PATH = Path(__file__).parent / "data"
MODEL_PATH = Path(__file__).parent / "models"
Expand Down Expand Up @@ -62,6 +63,8 @@ def test_help():

def test_train(tmp_path):
"""Test MLIP training."""
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"
model = "test.model"
compiled_model = "test_compiled.model"
logs_path = "logs"
Expand All @@ -82,6 +85,10 @@ def test_train(tmp_path):
"train",
"--mlip-config",
config,
"--log",
log_path,
"--summary",
summary_path,
],
)
try:
Expand All @@ -102,9 +109,33 @@ def test_train(tmp_path):

assert result.exit_code == 0

assert_log_contains(log_path, includes=["Starting training", "Training complete"])

# Read train summary file and check contents
assert summary_path.exists()
with open(summary_path, encoding="utf8") as file:
train_summary = yaml.safe_load(file)

assert "command" in train_summary
assert "janus train" in train_summary["command"]
assert "start_time" in train_summary
assert "inputs" in train_summary
assert "end_time" in train_summary

assert "emissions" in train_summary
assert train_summary["emissions"] > 0

# Clean up logger
logger = logging.getLogger()
logger.handlers = [
h for h in logger.handlers if not isinstance(h, logging.FileHandler)
]


def test_train_with_foundation(tmp_path):
"""Test MLIP training raises error with foundation_model in config."""
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"
config = write_tmp_config(DATA_PATH / "mlip_train_invalid.yml", tmp_path)

result = runner.invoke(
Expand All @@ -113,6 +144,10 @@ def test_train_with_foundation(tmp_path):
"train",
"--mlip-config",
config,
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 1
Expand All @@ -121,6 +156,9 @@ def test_train_with_foundation(tmp_path):

def test_fine_tune(tmp_path):
"""Test MLIP fine-tuning."""
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

model = "test-finetuned.model"
compiled_model = "test-finetuned_compiled.model"
logs_path = "logs"
Expand All @@ -137,7 +175,16 @@ def test_fine_tune(tmp_path):

result = runner.invoke(
app,
["train", "--mlip-config", config, "--fine-tune"],
[
"train",
"--mlip-config",
config,
"--fine-tune",
"--log",
log_path,
"--summary",
summary_path,
],
)
try:
assert Path(model).exists()
Expand All @@ -155,28 +202,57 @@ def test_fine_tune(tmp_path):
shutil.rmtree(results_path, ignore_errors=True)
shutil.rmtree(checkpoints_path, ignore_errors=True)

# Clean up logger
logger = logging.getLogger()
logger.handlers = [
h for h in logger.handlers if not isinstance(h, logging.FileHandler)
]

assert result.exit_code == 0


def test_fine_tune_no_foundation():
# NOTE(review): the `def` line above and the one below declare the same test,
# without and with `tmp_path`. This listing appears to be a diff rendered
# without +/- markers, so the first is presumably the superseded signature;
# the body below uses `tmp_path`, which only the second one provides.
def test_fine_tune_no_foundation(tmp_path):
"""Test MLIP fine-tuning raises errors without foundation_model."""
# Log and summary outputs are directed under tmp_path.
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

config = DATA_PATH / "mlip_fine_tune_no_foundation.yml"

# `runner` and `app` are defined outside this span; presumably a typer
# CliRunner and the janus CLI app imported at the top of the file.
result = runner.invoke(
app,
# NOTE(review): the one-line argument list directly below and the multi-line
# list after it are alternatives (presumably old and new side of the diff);
# only the multi-line form passes --log and --summary.
["train", "--mlip-config", config, "--fine-tune"],
[
"train",
"--mlip-config",
config,
"--fine-tune",
"--log",
log_path,
"--summary",
summary_path,
],
)
# The command is expected to fail: exit code 1 with a ValueError captured
# on the result.
assert result.exit_code == 1
assert isinstance(result.exception, ValueError)


def test_fine_tune_invalid_foundation():
# NOTE(review): the `def` line above and the one below declare the same test,
# without and with `tmp_path`. This listing appears to be a diff rendered
# without +/- markers, so the first is presumably the superseded signature;
# the body below uses `tmp_path`, which only the second one provides.
def test_fine_tune_invalid_foundation(tmp_path):
"""Test MLIP fine-tuning raises errors with invalid foundation_model."""
# Log and summary outputs are directed under tmp_path.
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"
config = DATA_PATH / "mlip_fine_tune_invalid_foundation.yml"

# `runner` and `app` are defined outside this span; presumably a typer
# CliRunner and the janus CLI app imported at the top of the file.
result = runner.invoke(
app,
# NOTE(review): the one-line argument list directly below and the multi-line
# list after it are alternatives (presumably old and new side of the diff);
# only the multi-line form passes --log and --summary.
["train", "--mlip-config", config, "--fine-tune"],
[
"train",
"--mlip-config",
config,
"--fine-tune",
"--log",
log_path,
"--summary",
summary_path,
],
)
# The command is expected to fail: exit code 1 with a ValueError captured
# on the result.
assert result.exit_code == 1
assert isinstance(result.exception, ValueError)

0 comments on commit e0de4a6

Please sign in to comment.