Skip to content

Commit

Permalink
trim the whitespace of the clients and gpu from the job simulator_run (
Browse files Browse the repository at this point in the history
…#2912)

* trim the whitespace of the clients and gpu from the job simulator_run.

* Added unit test.

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
yhwen and YuanTingHsieh authored Sep 6, 2024
1 parent 9cb3b98 commit 08ca469
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
8 changes: 8 additions & 0 deletions nvflare/job_config/fed_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,14 @@ def simulator_run(self, workspace, clients=None, n_clients=None, threads=None, g
+ workspace
)
if clients:
clients = self._trim_whitespace(clients)
command += " -c " + str(clients)
if n_clients:
command += " -n " + str(n_clients)
if threads:
command += " -t " + str(threads)
if gpu:
gpu = self._trim_whitespace(gpu)
command += " -gpu " + str(gpu)

new_env = os.environ.copy()
Expand Down Expand Up @@ -385,3 +387,9 @@ def _get_deploy_map(self):
deploy_map[app_name] = deploy_map.get(app_name, [])
deploy_map[app_name].append(site)
return deploy_map

def _trim_whitespace(self, string: str):
strings = string.split(",")
for i in range(len(strings)):
strings[i] = strings[i].strip()
return ",".join(strings)
8 changes: 8 additions & 0 deletions tests/unit_test/job_config/fed_job_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ def test_locate_imports(self):
with tempfile.NamedTemporaryFile(dir=cwd, suffix=".py") as dest_file:
imports = list(job_config.locate_imports(sf, dest_file=dest_file.name))
assert imports == expected

def test_trim_whitespace(self):
job_config = FedJobConfig(job_name="job_name", min_clients=1)
expected = "site-0,site-1"
assert expected == job_config._trim_whitespace("site-0,site-1")
assert expected == job_config._trim_whitespace("site-0, site-1")
assert expected == job_config._trim_whitespace(" site-0,site-1 ")
assert expected == job_config._trim_whitespace(" site-0, site-1 ")

0 comments on commit 08ca469

Please sign in to comment.