Skip to content

Commit

Permalink
Fixed several issues and the emulator runs now
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Feb 28, 2025
1 parent 77526c7 commit 88a4a6b
Show file tree
Hide file tree
Showing 16 changed files with 267 additions and 63 deletions.
66 changes: 66 additions & 0 deletions examples/advanced/edge/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Running Edge Example with Hierarchical Clients

Please follow these steps to run the Edge device emulator.

## Provision

Use tree_prov.py to generate a hierarchical NVFlare system with 2 levels and 2 clients at each level:

python tree_prov.py -p /tmp/edge_example -d 2 -w 2

This will create a deployment with 2 clients, 4 leaf clients, 2 relays, and 1 server.

The following file needs to be copied to the `local` folder of each leaf client (C11, C12, C21 and C22).

`edge__p_resources.json`:

```
{
"format_version": 2,
"components": [
{
"id": "web_agent",
"path": "nvflare.edge.widgets.web_agent.WebAgent",
"args": {}
},
{
"id": "etd",
"path": "nvflare.edge.widgets.etd.EdgeTaskDispatcher",
"args": {}
}
]
}
```

To start the system, run the following command:
./start_all.sh

## Starting Web Proxy

To route devices to different LCPs, routing_proxy is used. It's a simple proxy that routes each request to
a particular LCP based on a checksum of the device ID. It can be started like this:

python routing_proxy.py 8000 /tmp/edge_example/lcp_map.json

The lcp_map.json file is generated by tree_prov.py.

## Example Job

`hello_mobile` is a very simple job for testing the edge functions. It only sends one task, "train", and
prints the aggregated results.

The job can be started as usual in NVFlare admin console.

## Run Edge Emulator

The emulator can be used to test all the features of the edge system. It handles the 'train' task by simply doubling every value
in the weights.

To start the emulator, give it an endpoint URL and the number of devices, like this:

python run_emulator.py http://localhost:8000 16

The emulator keeps polling the LCP for a job assignment. It runs only one job and then quits.



Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"id": "learner",
"path": "nvflare.edge.executors.edge_dispatch_executor.EdgeDispatchExecutor",
"args": {
"wait_time": "30",
"wait_time": 60,
"min_devices": 2,
"aggregator_id": "aggregator"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self.num_rounds, private=True, sticky=False)
self.fire_event(AppEventType.TRAINING_STARTED, fl_ctx)

if self.current_round is None:
self.current_round = 1

while self.current_round < self.num_rounds:
for i in range(self.num_rounds):

self.current_round = i
if abort_signal.triggered:
return

Expand All @@ -75,6 +73,7 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
# Create train_task
task_data = Shareable()
task_data["weights"] = self.initial_weights
task_data["task_done"] = self.current_round >= (self.num_rounds - 1)
task_data.set_header(AppConstants.CURRENT_ROUND, self.current_round)
task_data.set_header(AppConstants.NUM_ROUNDS, self.num_rounds)
task_data.add_cookie(AppConstants.CONTRIBUTION_ROUND, self.current_round)
Expand All @@ -87,8 +86,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:

self.broadcast_and_wait(
task=train_task,
min_responses=4,
wait_time_after_min_received=10,
min_responses=1,
wait_time_after_min_received=30,
fl_ctx=fl_ctx,
abort_signal=abort_signal,
)
Expand All @@ -107,7 +106,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
if abort_signal.triggered:
return

self.log_info(fl_ctx, "Finished Mobile Training.")
final_weights = aggr_result.get("weights", None)
self.log_info(fl_ctx, f"Finished Mobile Training. Final weights: {final_weights}")
except Exception as e:
error_msg = f"Exception in mobile control_flow: {secure_format_exception(e)}"
self.log_exception(fl_ctx, error_msg)
Expand Down
6 changes: 3 additions & 3 deletions nvflare/app_common/executors/ham.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ def _do_execute(self, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

if received == 0:
# nothing received!
self.log_info(fl_ctx, "nothing received - timeout")
return make_reply(ReturnCode.TIMEOUT)
# nothing received! This maybe ok
self.log_warning(fl_ctx, "nothing received - timeout")
# return make_reply(ReturnCode.TIMEOUT)

try:
self.log_info(fl_ctx, "return aggregation result")
Expand Down
6 changes: 5 additions & 1 deletion nvflare/edge/aggregators/edge_result_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ def __init__(self):

def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
self.log_info(fl_ctx, f"Accepting: {shareable}")
self.num_devices += 1

w = shareable.get("weights")
if w is None:
return True

self.num_devices += 1
if self.weights is None:
self.weights = w
else:
Expand Down
40 changes: 31 additions & 9 deletions nvflare/edge/emulator/device_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nvflare.edge.web.models.api_error import ApiError
from nvflare.edge.web.models.device_info import DeviceInfo
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.task_response import TaskResponse
from nvflare.edge.web.models.user_info import UserInfo

log = logging.getLogger(__name__)
Expand All @@ -30,6 +31,7 @@ class DeviceEmulator:
def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo,
capabilities: Optional[dict], processor: DeviceTaskProcessor):
self.device_info = device_info
self.device_id = device_info.device_id
self.user_info = user_info
self.capabilities = capabilities
self.processor = processor
Expand All @@ -44,26 +46,30 @@ def run(self):
log.info(f"Received job: {job}")

while True:
task = self.eta_api.get_task(job)
log.info(f"Received task: {task}")
task = self.fetch_task(job)
if not task:
log.info(f"Job {job.job_Id} is done")
break
log.info(f"Device:{self.device_id} Received task: {task}")

# Catch exception
result = self.processor.process_task(task)
log.info(f"Task processed. Result: {result}")
log.info(f"Device:{self.device_id} Task processed. Result: {result}")
# Check result
result_response = self.eta_api.report_result(task, result)
log.info(f"Received result response: {result_response}")
if result_response.status == "DONE":
log.info(f"Device:{self.device_id} Received result response: {result_response}")
task_done = task.get("task_done", False)
if task_done or result_response.status == "DONE":
log.info(f"Job {job.job_id} {job.job_name} is done")
break
elif result_response.status != "OK":
log.error(f"Result report for task {task.task_name} is invalid")
log.error(f"Device:{self.device_id} Result report for task {task.task_name} is invalid")
continue

log.info(f"Task {task.task_name} result reported successfully")
log.info(f"Device:{self.device_id} Task {task.task_name} result reported successfully")

self.processor.shutdown()
log.info(f"Job {job.job_name} run ended")
log.info(f"Device:{self.device_id} Job {job.job_name} run ended")

except ApiError as error:
log.error(f"Status: {error.status}\nMessage: {str(error)}\nDetails: {error.details}")
Expand All @@ -79,5 +85,21 @@ def fetch_job(self) -> JobResponse:
return job
if job.status == "RETRY":
wait = job.retry_wait if job.retry_wait else 5
log.info(f"Retrying getting job in {wait} seconds")
log.info(f"Device:{self.device_id} Retrying getting job in {wait} seconds")
time.sleep(wait)

def fetch_task(self, job: JobResponse) -> Optional[TaskResponse]:
    """Poll the server until a task for ``job`` is returned or the job is over.

    Args:
        job: The job response for which ``self.eta_api.get_task`` is called.

    Returns:
        The task response when its status is "OK"; the task response with
        ``task["task_done"]`` set to True when its status is "DONE"; or
        None when the status is "NO_JOB".
    """
    default_wait = 5

    while True:
        task = self.eta_api.get_task(job)
        status = task.status

        if status == "OK":
            return task

        if status == "DONE":
            # Flag the final task so the caller can stop after processing it.
            task["task_done"] = True
            return task

        if status == "NO_JOB":
            return None

        if status == "RETRY":
            wait = task.retry_wait if task.retry_wait else default_wait
            log.info(f"Device:{self.device_id} Retrying getting task in {wait} seconds")
        else:
            # An unrecognized status previously fell through and re-polled
            # immediately, spinning in a tight loop against the server.
            # Back off before trying again.
            wait = default_wait
            log.warning(
                f"Device:{self.device_id} Unexpected task status {status}, retrying in {wait} seconds"
            )

        time.sleep(wait)
4 changes: 3 additions & 1 deletion nvflare/edge/emulator/eta_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def report_result(self, task: TaskResponse, result: dict) -> ResultResponse:
headers = {"Content-Type": "application/json"}
headers.update(self.common_headers)
params = {
"job_id": task.job_id,
"task_name": task.task_name,
"task_id": task.task_id,
}
Expand All @@ -80,4 +81,5 @@ def report_result(self, task: TaskResponse, result: dict) -> ResultResponse:
if code == 200:
return ResultResponse(**response.json())

raise ApiError(code, "ERROR", f"API Call failed with status code {code}", response.json())
details = {"response": response.text}
raise ApiError(code, "ERROR", f"API Call failed with status code {code}", details)
56 changes: 41 additions & 15 deletions nvflare/edge/emulator/run_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
from concurrent.futures import ThreadPoolExecutor, wait

from nvflare.edge.emulator.device_emulator import DeviceEmulator
from nvflare.edge.emulator.device_task_processor import DeviceTaskProcessor
from nvflare.edge.emulator.sample_task_processor import SampleTaskProcessor
from nvflare.edge.web.models.device_info import DeviceInfo
from nvflare.edge.web.models.user_info import UserInfo

log = logging.getLogger(__name__)


def run_emulator():
def device_run(endpoint_url: str, device_info: DeviceInfo, user_info: UserInfo, processor: DeviceTaskProcessor):
device_id = device_info.device_id
try:
capabilities = {
"methods": ["xgboost", "cnn"],
"cpu": 16,
"gpu": 1024
}
emulator = DeviceEmulator(endpoint_url, device_info, user_info, capabilities, processor)
emulator.run()

# read from JSON, a list of devices
device_info = DeviceInfo("1234", "flare_mobile", "1.0")
user_info = UserInfo("demo_id", "demo_user")
# Configure processor
processor = SampleTaskProcessor(device_info, user_info)
capabilities = {
"methods": ["xgboost", "cnn"],
"cpu": 16,
"gpu": 1024
}
endpoint = "http://localhost:4321"
emulator = DeviceEmulator(endpoint, device_info, user_info, capabilities, processor)
emulator.run()
log.info(f"Emulator run for device {device_id} ended")
except Exception as ex:
log.error(f"Device {device_id} failed to run: {ex}")


def run_emulator(endpoint_url: str, num: int):
    """Run ``num`` emulated devices concurrently against ``endpoint_url``.

    Each device gets its own DeviceInfo ("device-<index>"), a demo UserInfo
    and a SampleTaskProcessor, and is executed by ``device_run`` on a thread
    pool with one worker per device. The call returns once every device
    run has finished.

    Args:
        endpoint_url: URL passed through to ``device_run`` for each device.
        num: Number of devices to emulate (also the thread pool size).
    """
    pending = []
    with ThreadPoolExecutor(max_workers=num) as thread_pool:
        for index in range(num):
            info = DeviceInfo(f"device-{index}", "flare_mobile", "1.0")
            user = UserInfo("demo_id", "demo_user")
            task_processor = SampleTaskProcessor(info, user)
            pending.append(thread_pool.submit(device_run, endpoint_url, info, user, task_processor))

        # Block until all submitted device runs have completed.
        wait(pending)

    log.info("Emulator run ended")

Expand All @@ -48,4 +63,15 @@ def run_emulator():
handlers=[logging.StreamHandler()]
)

run_emulator()
n = len(sys.argv)
if n >= 2:
endpoint = sys.argv[1]
else:
endpoint = "http://localhost:9007"

if n >= 3:
num_devices = int(sys.argv[2])
else:
num_devices = 4

run_emulator(endpoint, num_devices)
9 changes: 6 additions & 3 deletions nvflare/edge/emulator/sample_task_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ def process_task(self, task: TaskResponse) -> dict:

result = None
if task.task_name == "train":
result = {
"model": [1.0, 2.0, 3.0, 4.0]
}
weights = task.task_data["weights"]
if weights:
w = [x * 2.0 for x in weights]
else:
w = [0, 0, 0, 0]
result = {"weights": w}
elif task.task_name == "validate":
result = {
"accuracy": [0.01, 0.02, 0.03, 0.04]
Expand Down
Loading

0 comments on commit 88a4a6b

Please sign in to comment.