From 45bb06d00713f559919eb79509e435e9172bca45 Mon Sep 17 00:00:00 2001 From: k-harris27 <120191386+k-harris27@users.noreply.github.com> Date: Fri, 7 Feb 2025 21:23:47 +0000 Subject: [PATCH] Add MD progress bar as class --- janus_core/calculations/md.py | 152 +++++++++++++++++++++--- janus_core/calculations/phonons.py | 6 +- janus_core/calculations/single_point.py | 19 +-- janus_core/cli/md.py | 21 ++++ janus_core/helpers/utils.py | 63 +++++----- tests/test_md.py | 43 ++++++- tests/test_md_cli.py | 5 + 7 files changed, 246 insertions(+), 63 deletions(-) diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index cf8cca2e..8f176d52 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -10,9 +10,13 @@ from os.path import getmtime from pathlib import Path import random -from typing import Any +from typing import TYPE_CHECKING, Any from warnings import warn +# Only needed for type hints +if TYPE_CHECKING: + from rich.progress import TaskID + from ase import Atoms from ase.geometry.analysis import Analysis from ase.io import read @@ -43,7 +47,12 @@ PostProcessKwargs, ) from janus_core.helpers.struct_io import input_structs, output_structs -from janus_core.helpers.utils import none_to_dict, set_minimize_logging, write_table +from janus_core.helpers.utils import ( + ProgressBar, + none_to_dict, + set_minimize_logging, + write_table, +) from janus_core.processing.correlator import Correlation from janus_core.processing.post_process import compute_rdf, compute_vaf @@ -157,6 +166,11 @@ class MolecularDynamics(BaseCalculation): seed Random seed used by numpy.random and random functions, such as in Langevin. Default is None. + enable_progress_bar + Whether to show a progress bar. Default is False. + update_progress_every + How many timesteps between progress bar updates. + Default is steps/100, rounded up. Attributes ---------- @@ -218,6 +232,8 @@ def __init__( post_process_kwargs: PostProcessKwargs | None = None, correlation_kwargs: list[CorrelationKwargs] | None = None, seed: int | None = None, + enable_progress_bar: bool = False, + update_progress_every: int | None = None, ) -> None: """ Initialise molecular dynamics simulation configuration. @@ -325,6 +341,11 @@ def __init__( seed Random seed used by numpy.random and random functions, such as in Langevin. Default is None. + enable_progress_bar + Whether to show a progress bar. Default is False. + update_progress_every + How many timesteps between progress bar updates. + Default is steps/100, rounded up. """ ( read_kwargs, @@ -372,6 +393,8 @@ def __init__( self.post_process_kwargs = post_process_kwargs self.correlation_kwargs = correlation_kwargs self.seed = seed + self.enable_progress_bar = enable_progress_bar + self.update_progress_every = update_progress_every if "append" in self.write_kwargs: raise ValueError("`append` cannot be specified when writing files") @@ -421,6 +444,15 @@ def __init__( "Temperature ramp requested for ensemble with no thermostat" ) + if self.ramp_temp: + self.heating_steps_per_temp = int(self.temp_time // self.timestep) + + # Always include start temperature in ramp, and include end temperature + # if separated by an integer number of temperature steps + self.heating_n_temps = int( + 1 + abs(self.temp_end - self.temp_start) // self.temp_step + ) + # Validate start and end temperatures if self.temp_start < 0 or self.temp_end < 0: raise ValueError("Start and end temperatures must be positive") @@ -430,6 +462,9 @@ def __init__( "Temperature ramp step time cannot be less than 1 timestep" ) + if self.update_progress_every is None: + self.update_progress_every = np.ceil(self.steps / 100) + # Read last image by default read_kwargs.setdefault("index", -1) @@ -1044,6 +1079,89 @@ def _set_target_temperature(self, temperature: float): else: raise ValueError("Temperature set for ensemble with no thermostat.") + def _init_progress_bar(self) -> ProgressBar: + """ + Initialise MD progress bar. + + Returns + ------- + ProgressBar + Object used for managing progress bars. + """ + if not self.enable_progress_bar: + self._progress_bar = ProgressBar(disable=True) + return self._progress_bar + self._progress_bar = ProgressBar() + total_steps = self.steps + + # Set total expected MD steps. + if self.ramp_temp: + total_steps += self.heating_n_temps * self.heating_steps_per_temp + # Heating steps at 0 are skipped. + if np.isclose(self.temp_start, 0.0): + total_steps -= self.heating_steps_per_temp + + total_task_id = self._progress_bar.add_task( + "Performing MD simulation...", + total=total_steps, + completed=self.offset, + ) + ramp_task_id = None + if self.ramp_temp: + ramp_task_id = self._progress_bar.add_task( + "", + visible=False, + ) + + update_func = partial(self._update_progress_bar, total_task_id, ramp_task_id) + self.dyn.attach(update_func, interval=self.update_progress_every) + # Also ensure progress is updated at the end + self.dyn.attach(update_func, interval=-total_steps) + return self._progress_bar + + def _update_progress_bar( + self, total_task_id: TaskID, ramp_task_id: TaskID | None = None + ): + """ + Update the progress bar for MD run. + + Parameters + ---------- + total_task_id + Task ID tracking overall simulation progress. + ramp_task_id + Task ID tracking progress of individual temperature ramp steps (Optional). + """ + current_step = self.dyn.nsteps + self.offset + + self._progress_bar.update(total_task_id, completed=current_step) + + if ramp_task_id: + current_ramp_step = current_step // self.heating_steps_per_temp + + # Account for MD temperature steps at T=0 K being skipped. + heating_n_temps = self.heating_n_temps + if np.isclose(self.temp_start, 0.0): + heating_n_temps -= 1 + + if current_ramp_step < heating_n_temps: + description = f"Temperature ramp ({self.temp} K)..." + completed = current_step % self.heating_steps_per_temp + total = self.heating_steps_per_temp + else: + description = f"Constant temperature ({self.temp} K)..." + completed = current_step - heating_n_temps * self.heating_steps_per_temp + total = self.steps + self._progress_bar.update( + ramp_task_id, + description=description, + completed=completed, + total=total, + visible=True, + ) + + self._progress_bar.refresh() + def run(self) -> None: """Run molecular dynamics simulation and/or temperature ramp.""" unit_keys = ( @@ -1080,9 +1198,10 @@ def run(self) -> None: if self.minimize and self.minimize_every > 0: self.dyn.attach(self._optimize_structure, interval=self.minimize_every) - # Note current time - self.struct.info["real_time"] = datetime.datetime.now() - self._run_dynamics() + with self._init_progress_bar(): + # Note current time + self.struct.info["real_time"] = datetime.datetime.now() + self._run_dynamics() if self.post_process_kwargs: self._post_process() @@ -1102,27 +1221,24 @@ def _run_dynamics(self) -> None: # Run temperature ramp if self.ramp_temp: - heating_steps = int(self.temp_time // self.timestep) - if self.logger and not np.isclose(self.temp_time % self.timestep, 0.0): - rounded_temp_step = heating_steps * self.timestep / units.fs + rounded_temp_step = ( + self.heating_steps_per_temp * self.timestep / units.fs + ) self.logger.info( "Temperature ramp step time rounded to nearest timestep " f"({rounded_temp_step:.5} fs)" ) - # Always include start temperature in ramp, and include end temperature - # if separated by an integer number of temperature steps - n_temps = int(1 + abs(self.temp_end - self.temp_start) // self.temp_step) - # Add or subtract temperatures ramp_sign = 1 if (self.temp_end - self.temp_start) > 0 else -1 temps = [ - self.temp_start + ramp_sign * i * self.temp_step for i in range(n_temps) + self.temp_start + ramp_sign * i * self.temp_step + for i in range(self.heating_n_temps) ] if self.restart: - ramp_steps_completed = self.offset // heating_steps + ramp_steps_completed = self.offset // self.heating_steps_per_temp ramp_steps_completed = min(ramp_steps_completed, len(temps)) if isclose(self.temp_start, 0.0): # T~0K steps do not run any MD, so are not included in the offset. @@ -1139,10 +1255,10 @@ def _run_dynamics(self) -> None: first_step = True for temp in temps: self.temp = temp - steps = heating_steps + steps = self.heating_steps_per_temp if first_step: first_step = False - steps -= self.offset % heating_steps + steps -= self.offset % self.heating_steps_per_temp self._set_velocity_distribution() if isclose(temp, 0.0): # Calculate forces and energies to be output @@ -1167,7 +1283,9 @@ def _run_dynamics(self) -> None: if self.restart and self.ramp_temp: # Take the ramp time off the offset for MD. # If restarting during the ramp, MD has no offset. - md_offset = max(0, md_offset - heating_steps * n_temps) + md_offset = max( + 0, md_offset - self.heating_steps_per_temp * self.heating_n_temps + ) # Run MD if self.steps > 0: diff --git a/janus_core/calculations/phonons.py b/janus_core/calculations/phonons.py index 806f017c..de1279e7 100644 --- a/janus_core/calculations/phonons.py +++ b/janus_core/calculations/phonons.py @@ -27,7 +27,7 @@ PathLike, PhononCalcs, ) -from janus_core.helpers.utils import none_to_dict, set_minimize_logging, track_progress +from janus_core.helpers.utils import ProgressBar, none_to_dict, set_minimize_logging class Phonons(BaseCalculation): @@ -426,8 +426,8 @@ def calc_force_constants( disp_supercells = phonon.supercells_with_displacements if self.enable_progress_bar: - disp_supercells = track_progress( - disp_supercells, "Computing displacements..." + disp_supercells = ProgressBar().track( + disp_supercells, description="Computing displacements..." ) phonon.forces = [ diff --git a/janus_core/calculations/single_point.py b/janus_core/calculations/single_point.py index 0474d731..b7200038 100644 --- a/janus_core/calculations/single_point.py +++ b/janus_core/calculations/single_point.py @@ -22,7 +22,7 @@ ) from janus_core.helpers.mlip_calculators import check_calculator from janus_core.helpers.struct_io import output_structs -from janus_core.helpers.utils import none_to_dict, track_progress +from janus_core.helpers.utils import ProgressBar, none_to_dict class SinglePoint(BaseCalculation): @@ -239,8 +239,8 @@ def _get_potential_energy(self) -> MaybeList[float]: if isinstance(self.struct, Sequence): struct_sequence = self.struct if self.enable_progress_bar: - struct_sequence = track_progress( - struct_sequence, "Computing potential energies..." + struct_sequence = ProgressBar().track( + struct_sequence, description="Computing potential energies..." ) return [struct.get_potential_energy() for struct in struct_sequence] @@ -258,7 +258,9 @@ def _get_forces(self) -> MaybeList[ndarray]: if isinstance(self.struct, Sequence): struct_sequence = self.struct if self.enable_progress_bar: - struct_sequence = track_progress(struct_sequence, "Computing forces...") + struct_sequence = ProgressBar().track( + struct_sequence, description="Computing forces..." + ) return [struct.get_forces() for struct in struct_sequence] return self.struct.get_forces() @@ -275,8 +277,8 @@ def _get_stress(self) -> MaybeList[ndarray]: if isinstance(self.struct, Sequence): struct_sequence = self.struct if self.enable_progress_bar: - struct_sequence = track_progress( - struct_sequence, "Computing stresses..." + struct_sequence = ProgressBar().track( + struct_sequence, description="Computing stresses..." ) return [struct.get_stress() for struct in struct_sequence] @@ -319,8 +321,9 @@ def _get_hessian(self) -> MaybeList[ndarray]: if isinstance(self.struct, Sequence): struct_sequence = self.struct if self.enable_progress_bar: - struct_sequence = track_progress( - struct_sequence, "Computing Hessian..." + print("There should be a progress bar...") + struct_sequence = ProgressBar().track( + struct_sequence, description="Computing Hessian..." ) return [self._calc_hessian(struct) for struct in struct_sequence] diff --git a/janus_core/cli/md.py b/janus_core/cli/md.py index d802cd1e..c946d573 100644 --- a/janus_core/cli/md.py +++ b/janus_core/cli/md.py @@ -215,6 +215,20 @@ def md( tracker: Annotated[ bool, Option(help="Whether to save carbon emissions of calculation") ] = True, + enable_progress_bar: Annotated[ + bool, + Option( + "--enable-progress-bar/--disable-progress-bar", + help="Whether to show progress bar.", + ), + ] = True, + update_progress_every: Annotated[ + int, + Option( + help="How many timesteps between progress bar updates. " + "Default is steps/100, rounded up." + ), + ] = None, summary: Summary = None, ) -> None: """ @@ -343,6 +357,11 @@ def md( tracker Whether to save carbon emissions of calculation in log file and summary. Default is True. + enable_progress_bar + Whether to show progress bar. + update_progress_every + How many timesteps between progress bar updates. + Default is steps/100, rounded up. summary Path to save summary of inputs, start/end time, and carbon emissions. Default is inferred from the name of the structure file. @@ -456,6 +475,8 @@ def md( "write_kwargs": write_kwargs, "post_process_kwargs": post_process_kwargs, "seed": seed, + "enable_progress_bar": enable_progress_bar, + "update_progress_every": update_progress_every, } # Instantiate MD ensemble diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index abb549df..082bd6a7 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -368,45 +368,44 @@ def _dump_csv( print(",".join(map(format, cols, formats)), file=file) -def track_progress(sequence: Sequence | Iterable, description: str) -> Iterable: +class ProgressBar(Progress): """ - Track the progress of iterating over a sequence. + Progress bar with preset formatting. - This is done by displaying a progress bar in the console using the rich library. - The function is an iterator over the sequence, updating the progress bar each - iteration. + Inherits from `rich.progress.Progress`, providing preset formatting. Parameters ---------- - sequence - The sequence to iterate over. Must support "len". - description - The text to display to the left of the progress bar. - - Yields - ------ - Iterable - An iterable of the values in the sequence. + **kwargs + Keyword arguments passed on to `rich.progress.Progress`. """ - text_column = TextColumn("{task.description}") - bar_column = BarColumn( - bar_width=None, - complete_style=Style(color="#FBBB10"), - finished_style=Style(color="#E38408"), - ) - completion_column = MofNCompleteColumn() - time_column = TimeRemainingColumn() - progress = Progress( - text_column, - bar_column, - completion_column, - time_column, - expand=True, - auto_refresh=False, - ) - with progress: - yield from progress.track(sequence, description=description) + def __init__(self, **kwargs): + """ + Initialise a `rich` progress bar with preset formatting. + + Parameters + ---------- + **kwargs + Keyword arguments passed on to `rich.progress.Progress`. + """ + text_column = TextColumn("{task.description}") + bar_column = BarColumn( + bar_width=None, + complete_style=Style(color="#FBBB10"), + finished_style=Style(color="#E38408"), + ) + completion_column = MofNCompleteColumn() + time_column = TimeRemainingColumn() + super().__init__( + text_column, + bar_column, + completion_column, + time_column, + expand=True, + auto_refresh=False, + **kwargs, + ) def check_files_exist(config: dict, req_file_keys: Sequence[PathLike]) -> None: diff --git a/tests/test_md.py b/tests/test_md.py index 6c5f4524..8f964688 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -717,7 +717,7 @@ def test_stats(tmp_path, ensemble, tag): @pytest.mark.parametrize("ensemble", ensembles_with_thermostat) -def test_heating(tmp_path, ensemble): +def test_heating(tmp_path, capsys, ensemble): """Test heating with no MD.""" file_prefix = tmp_path / "NaCl-heating" final_file = tmp_path / "NaCl-heating-final.extxyz" @@ -741,6 +741,7 @@ def test_heating(tmp_path, ensemble): temp_step=20, temp_time=2, log_kwargs={"filename": log_file}, + enable_progress_bar=True, ) md.run() assert_log_contains( @@ -754,6 +755,9 @@ def test_heating(tmp_path, ensemble): assert final_file.exists() + # Check progress bar has completed. + assert "━━ 2/2" in capsys.readouterr().out + @pytest.mark.parametrize("ensemble", ensembles_without_thermostat) def test_no_thermostat_heating(tmp_path, ensemble): @@ -807,7 +811,7 @@ def test_noramp_heating(tmp_path, ensemble): @pytest.mark.parametrize("ensemble", ensembles_with_thermostat) -def test_heating_md(tmp_path, ensemble): +def test_heating_md(tmp_path, capsys, ensemble): """Test heating followed by MD.""" file_prefix = tmp_path / "NaCl-heating" stats_path = tmp_path / "NaCl-heating-stats.dat" @@ -830,6 +834,7 @@ def test_heating_md(tmp_path, ensemble): temp_step=10, temp_time=2, log_kwargs={"filename": log_file}, + enable_progress_bar=True, ) md.run() assert_log_contains( @@ -860,6 +865,11 @@ def test_heating_md(tmp_path, ensemble): assert stat_data.units[0] == "" assert stat_data.units[target_t_col] == "K" + # Check progress bar has completed. + out = capsys.readouterr().out + assert "━━ 9/9" in out # Total progress + assert "━━ 5/5" in out # Const T progress + def test_heating_restart(tmp_path): """Test restarting during temperature ramp.""" @@ -1164,7 +1174,7 @@ def test_logging(tmp_path): assert single_point.struct.info["emissions"] > 0 -def test_auto_restart(tmp_path): +def test_auto_restart(tmp_path, capsys): """Test auto restarting simulation.""" # tmp_path for all files other than restart # Include T300.0 to test Path.stem vs Path.name @@ -1223,6 +1233,7 @@ def test_auto_restart(tmp_path): traj_every=1, final_file=final_path, log_kwargs={"filename": log_file}, + enable_progress_bar=True, ) assert_log_contains(log_file, includes="Auto restart successful") @@ -1251,6 +1262,9 @@ def test_auto_restart(tmp_path): final_traj = read(traj_path, index=":") assert len(final_traj) == 8 + # Check progress bar has completed. + assert "━━ 7/7" in capsys.readouterr().out + finally: restart_path.unlink(missing_ok=True) @@ -1362,3 +1376,26 @@ def test_set_info(tmp_path): final_struct = read(traj_path, index="-1") assert npt.struct.info["density"] == pytest.approx(2.120952627887493) assert final_struct.info["density"] == pytest.approx(2.120952627887493) + + +@pytest.mark.parametrize("ensemble, tag", test_data) +def test_progress_bar_complete(tmp_path, capsys, ensemble, tag): + """Test progress bar completes in all ensembles.""" + file_prefix = tmp_path / f"Cl4Na4-{tag}-T300.0" + + single_point = SinglePoint( + struct=DATA_PATH / "NaCl.cif", + arch="mace", + calc_kwargs={"model": MODEL_PATH}, + ) + md = ensemble( + struct=single_point.struct, + steps=2, + file_prefix=file_prefix, + enable_progress_bar=True, + ) + + md.run() + + # Check progress bar has completed. + assert "━━ 2/2" in capsys.readouterr().out diff --git a/tests/test_md_cli.py b/tests/test_md_cli.py index 0cd0315c..934795ef 100644 --- a/tests/test_md_cli.py +++ b/tests/test_md_cli.py @@ -140,6 +140,8 @@ def test_md(ensemble): for prop, units in expected_units.items(): assert atoms.info["units"][prop] == units + assert "━━ 2/2" in result.output + finally: final_path.unlink(missing_ok=True) restart_path.unlink(missing_ok=True) @@ -786,6 +788,9 @@ def test_auto_restart(tmp_path): assert len(lines) == 6 assert int(lines[-1].split()[0]) == 7 + # Check progress bar counted restart steps correctly + assert "7/7" in result.stdout + def test_no_carbon(tmp_path): """Test disabling carbon tracking."""