From 8d3c8e7b928baa9e6a3bd963eb8324d706baaf3d Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Thu, 27 Feb 2025 19:58:43 +0000 Subject: [PATCH] Save MD postprocessing filenames --- janus_core/calculations/md.py | 112 +++++++++++++++++++++++----------- janus_core/cli/utils.py | 4 ++ 2 files changed, 80 insertions(+), 36 deletions(-) diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index b2499c1f..caa4913a 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -542,6 +542,16 @@ def output_files(self) -> None: if self.minimize_kwargs["write_results"] else None ) + output_files["rdfs"] = ( + self._rdf_files + if self.post_process_kwargs.get("rdf_compute", False) + else None + ) + output_files["vafs"] = ( + self._vaf_files + if self.post_process_kwargs.get("vaf_compute", False) + else None + ) return output_files @@ -739,6 +749,70 @@ def _correlations_file(self) -> str: """ return self._build_filename("cor.dat", self.param_prefix) + @property + def _rdf_files(self) -> tuple[Path]: + """ + Get RDF filenames. + + Returns + ------- + str + Filenames for RDF file. + """ + base_name = self.post_process_kwargs.get("rdf_output_file", None) + rdf_args = { + name: self.post_process_kwargs.get(key, default) + for name, (key, default) in ( + ("elements", ("rdf_elements", None)), + ("by_elements", ("rdf_by_elements", False)), + ) + } + + if rdf_args["by_elements"]: + elements = ( + tuple(sorted(set(self.struct.get_chemical_symbols()))) + if rdf_args["elements"] is None + else rdf_args["elements"] + ) + + out_paths = tuple( + self._build_filename( + "rdf.dat", + self.param_prefix, + "_".join(element), + prefix_override=base_name, + ) + for element in combinations_with_replacement(elements, 2) + ) + + else: + out_paths = ( + self._build_filename( + "rdf.dat", self.param_prefix, prefix_override=base_name + ), + ) + + return out_paths + + @property + def _vaf_files(self) -> str: + """ + Define VAF filenames. + + Returns + ------- + str + Filenames for VAF files. + """ + file_names = self.post_process_kwargs.get("vaf_output_files", None) + if not isinstance(file_names, Sequence): + file_names = (file_names,) + + return tuple( + self._build_filename("vaf.dat", self.param_prefix, filename=file_name) + for file_name in file_names + ) + def _parse_correlations(self) -> None: """Parse correlation kwargs into Correlations.""" if self.correlation_kwargs: @@ -970,7 +1044,6 @@ def _post_process(self) -> None: ana = Analysis(data) if self.post_process_kwargs.get("rdf_compute", False): - base_name = self.post_process_kwargs.get("rdf_output_file", None) rdf_args = { name: self.post_process_kwargs.get(key, default) for name, (key, default) in ( @@ -987,45 +1060,12 @@ def _post_process(self) -> None: ) rdf_args["index"] = slice_ - if rdf_args["by_elements"]: - elements = ( - tuple(sorted(set(data[0].get_chemical_symbols()))) - if rdf_args["elements"] is None - else rdf_args["elements"] - ) - - out_paths = [ - self._build_filename( - "rdf.dat", - self.param_prefix, - "_".join(element), - prefix_override=base_name, - ) - for element in combinations_with_replacement(elements, 2) - ] - - else: - out_paths = [ - self._build_filename( - "rdf.dat", self.param_prefix, prefix_override=base_name - ) - ] - - compute_rdf(data, ana, filenames=out_paths, **rdf_args) + compute_rdf(data, ana, filenames=self._rdf_files, **rdf_args) if self.post_process_kwargs.get("vaf_compute", False): - file_names = self.post_process_kwargs.get("vaf_output_files", None) use_vel = self.post_process_kwargs.get("vaf_velocities", False) fft = self.post_process_kwargs.get("vaf_fft", False) - if not isinstance(file_names, Sequence): - file_names = (file_names,) - - out_paths = tuple( - self._build_filename("vaf.dat", self.param_prefix, filename=file_name) - for file_name in file_names - ) - slice_ = ( self.post_process_kwargs.get("vaf_start", 0), self.post_process_kwargs.get("vaf_stop", None), @@ -1034,7 +1074,7 @@ def _post_process(self) -> None: compute_vaf( data, - out_paths, + self._vaf_files, use_velocities=use_vel, fft=fft, index=slice_, diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index f3a4af39..6eeecc03 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -37,6 +37,10 @@ def dict_paths_to_strs(dictionary: dict) -> None: for key, value in dictionary.items(): if isinstance(value, dict): dict_paths_to_strs(value) + elif isinstance(value, Sequence) and not isinstance(value, str): + dictionary[key] = [ + str(path) if isinstance(path, Path) else path for path in value + ] elif isinstance(value, Path): dictionary[key] = str(value)