Skip to content

Commit

Permalink
Fix for driver-config validation in combination with keypath (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
maddenp-noaa authored Aug 9, 2024
1 parent cb2700c commit 87fe26f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
21 changes: 10 additions & 11 deletions src/uwtools/drivers/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ def __init__(
}
)
self._config_full: dict = config_input.data
config_intermediate, _ = walk_key_path(self._config_full, key_path or [])
self._platform = config_intermediate.get("platform")
self._config_intermediate, _ = walk_key_path(self._config_full, key_path or [])
try:
self._config: dict = config_intermediate[self.driver_name]
self._config: dict = self._config_intermediate[self.driver_name]
except KeyError as e:
raise UWConfigError("Required '%s' block missing in config" % self.driver_name) from e
if controller:
self._config[STR.rundir] = config_intermediate[controller][STR.rundir]
self._config[STR.rundir] = self._config_intermediate[controller][STR.rundir]
self._validate(schema_file)
dryrun(enable=dry_run)

Expand Down Expand Up @@ -200,10 +199,10 @@ def _validate(self, schema_file: Optional[Path] = None) -> None:
:raises: UWConfigError if config fails validation.
"""
if schema_file:
validate_external(schema_file=schema_file, config=self.config_full)
validate_external(schema_file=schema_file, config=self._config_intermediate)
else:
validate_internal(
schema_name=self.driver_name.replace("_", "-"), config=self.config_full
schema_name=self.driver_name.replace("_", "-"), config=self._config_intermediate
)


Expand Down Expand Up @@ -390,13 +389,13 @@ def _run_resources(self) -> dict[str, Any]:
"""
Returns platform configuration data.
"""
if not self._platform:
if not (platform := self._config_intermediate.get("platform")):
raise UWConfigError("Required 'platform' block missing in config")
threads = self.config.get(STR.execution, {}).get(STR.threads)
return {
STR.account: self._platform[STR.account],
STR.account: platform[STR.account],
STR.rundir: self.rundir,
STR.scheduler: self._platform[STR.scheduler],
STR.scheduler: platform[STR.scheduler],
STR.stdout: "%s.out" % self._runscript_path.name, # config may override
**({STR.threads: threads} if threads else {}),
**self.config.get(STR.execution, {}).get(STR.batchargs, {}),
Expand Down Expand Up @@ -481,10 +480,10 @@ def _validate(self, schema_file: Optional[Path] = None) -> None:
:raises: UWConfigError if config fails validation.
"""
if schema_file:
validate_external(schema_file=schema_file, config=self.config_full)
validate_external(schema_file=schema_file, config=self._config_intermediate)
else:
validate_internal(
schema_name=self.driver_name.replace("_", "-"), config=self.config_full
schema_name=self.driver_name.replace("_", "-"), config=self._config_intermediate
)
validate_internal(schema_name=STR.platform, config=self.config_full)

Expand Down
11 changes: 9 additions & 2 deletions src/uwtools/tests/drivers/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def test_Assets_key_path(config, tmp_path):
config=config_file, dry_run=False, key_path=["foo", "bar"]
)
assert assetsobj.config == config[assetsobj.driver_name]
assert assetsobj._platform == config["platform"]


def test_Assets_leadtime(config):
Expand All @@ -236,6 +235,14 @@ def test_Assets_validate(assetsobj, caplog):
assert regex_logged(caplog, "State: Ready")


def test_Assets_validate_key_path(config, controller_schema):
config = {"a": {"b": config}}
with patch.object(ConcreteAssetsTimeInvariant, "_validate", driver.Assets._validate):
assert ConcreteAssetsTimeInvariant(
config=config, key_path=["a", "b"], schema_file=controller_schema
)


@mark.parametrize(
"base_file,update_values,expected",
[
Expand Down Expand Up @@ -442,7 +449,7 @@ def test_Driver__namelist_schema_default_disable(driverobj):


def test_Driver__run_resources_fail(driverobj):
driverobj._platform = None
del driverobj._config_intermediate["platform"]
with raises(UWConfigError) as e:
assert driverobj._run_resources
assert str(e.value) == "Required 'platform' block missing in config"
Expand Down

0 comments on commit 87fe26f

Please sign in to comment.