Skip to content

Commit

Permalink
Works on multi GPU
Browse files Browse the repository at this point in the history
Handles the scheduler in a similar fashion
Removes a slight hack from ar_model
  • Loading branch information
Simon Adamov committed May 1, 2024
1 parent b0050b9 commit f8517e4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
12 changes: 10 additions & 2 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def __init__(self, args):
if self.output_std:
self.test_metrics["output_std"] = [] # Treat as metric

# For making restoring of optimizer state optional (slight hack)
self.opt_state = None
# For making restoring of optimizer state optional
self.restore_opt = args.restore_opt

# For example plotting
self.n_example_pred = args.n_example_pred
Expand Down Expand Up @@ -593,3 +593,11 @@ def on_load_checkpoint(self, checkpoint):
)
loaded_state_dict[new_key] = loaded_state_dict[old_key]
del loaded_state_dict[old_key]
if not self.restore_opt:
optimizers, lr_schedulers = self.configure_optimizers()
checkpoint["optimizer_states"] = [
opt.state_dict() for opt in optimizers
]
checkpoint["lr_schedulers"] = [
sched.state_dict() for sched in lr_schedulers
]
12 changes: 3 additions & 9 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,7 @@ def main():

# Load model parameters Use new args for model
model_class = MODELS[args.model]
if args.load:
model = model_class.load_from_checkpoint(args.load, args=args)
if args.restore_opt:
# Save for later
# Unclear if this works for multi-GPU
model.opt_state = torch.load(args.load)["optimizer_states"][0]
else:
model = model_class(args)
model = model_class(args)

prefix = "subset-" if args.subset_ds else ""
if args.eval:
Expand Down Expand Up @@ -300,13 +293,14 @@ def main():
)

print(f"Running evaluation on {args.eval}")
trainer.test(model=model, dataloaders=eval_loader)
trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load)
else:
# Train model
trainer.fit(
model=model,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
ckpt_path=args.load,
)


Expand Down

0 comments on commit f8517e4

Please sign in to comment.