diff --git a/examples/advanced/edge/edge__p_resources.json b/examples/advanced/edge/edge__p_resources.json new file mode 100644 index 0000000000..d5760562bf --- /dev/null +++ b/examples/advanced/edge/edge__p_resources.json @@ -0,0 +1,15 @@ +{ + "format_version": 2, + "components": [ + { + "id": "web_agent", + "path": "nvflare.edge.widgets.web_agent.WebAgent", + "args": {} + }, + { + "id": "etd", + "path": "nvflare.edge.widgets.etd.EdgeTaskDispatcher", + "args": {} + } + ] +} diff --git a/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_client.json b/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_client.json index d4e374c279..d52b6f5a18 100644 --- a/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_client.json +++ b/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_client.json @@ -12,7 +12,7 @@ "args": { "learner_id": "learner", "aggregator_id": "aggregator", - "aggr_timeout": 60, + "aggr_timeout": 600, "min_responses": 2, "wait_time_after_min_resps_received": 5 } @@ -24,7 +24,7 @@ "id": "learner", "path": "nvflare.edge.executors.edge_dispatch_executor.EdgeDispatchExecutor", "args": { - "wait_time": 60, + "wait_time": 600, "min_devices": 2, "aggregator_id": "aggregator" } diff --git a/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_server.json b/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_server.json index 49f0852fb0..b0f6ea7b5a 100644 --- a/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_server.json +++ b/examples/advanced/edge/jobs/hello_mobile/app/config/config_fed_server.json @@ -1,15 +1,13 @@ { "format_version": 2, - "num_rounds": 2, + "num_rounds": 1, "workflows": [ { "id": "edge_controller", "path": "edge_controller.SimpleEdgeController", "args": { "num_rounds": "{num_rounds}", - "initial_weights": [ - 1.0, 2.0, 3.0, 4.0 - ] + "model_string_file": "/Users/yuantingh/NVFlare/examples/advanced/edge/xor_model.txt" } } ] diff --git 
a/examples/advanced/edge/jobs/hello_mobile/app/custom/edge_controller.py b/examples/advanced/edge/jobs/hello_mobile/app/custom/edge_controller.py index acae5f96a7..466cd8aae1 100644 --- a/examples/advanced/edge/jobs/hello_mobile/app/custom/edge_controller.py +++ b/examples/advanced/edge/jobs/hello_mobile/app/custom/edge_controller.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any -from nvflare.apis.controller_spec import Task, ClientTask + +from nvflare.apis.controller_spec import ClientTask, Task from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.impl.controller import Controller @@ -21,32 +21,24 @@ from nvflare.apis.signal import Signal from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.app_event_type import AppEventType -from nvflare.edge.aggregators.edge_result_accumulator import EdgeResultAccumulator from nvflare.security.logging import secure_format_exception class SimpleEdgeController(Controller): - def __init__( - self, - num_rounds: int, - initial_weights: Any - ): + def __init__(self, num_rounds: int, model_string_file: str): super().__init__() self.num_rounds = num_rounds self.current_round = None - self.initial_weights = initial_weights - self.aggregator = None + with open(model_string_file, "r") as f: + model_string = f.read() + self.model_string = model_string def start_controller(self, fl_ctx: FLContext) -> None: self.log_info(fl_ctx, "Initializing Simple mobile workflow.") - self.aggregator = EdgeResultAccumulator() - # initialize global model fl_ctx.set_prop(AppConstants.START_ROUND, 1, private=True, sticky=True) fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self.num_rounds, private=True, sticky=False) - fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.initial_weights, private=True, sticky=True) 
- self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx) def stop_controller(self, fl_ctx: FLContext): self.log_info(fl_ctx, "Stopping Simple mobile workflow.") @@ -66,14 +58,17 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: return self.log_info(fl_ctx, f"Round {self.current_round} started.") - fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.initial_weights, private=True, sticky=True) - fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self.current_round, private=True, sticky=True) + fl_ctx.set_prop( + AppConstants.CURRENT_ROUND, + self.current_round, + private=True, + sticky=True, + ) self.fire_event(AppEventType.ROUND_STARTED, fl_ctx) # Create train_task task_data = Shareable() - task_data["weights"] = self.initial_weights - task_data["task_done"] = self.current_round >= (self.num_rounds - 1) + task_data["model"] = self.model_string task_data.set_header(AppConstants.CURRENT_ROUND, self.current_round) task_data.set_header(AppConstants.NUM_ROUNDS, self.num_rounds) task_data.add_cookie(AppConstants.CONTRIBUTION_ROUND, self.current_round) @@ -95,19 +90,7 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: if abort_signal.triggered: return - self.log_info(fl_ctx, "Start aggregation.") - self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx) - aggr_result = self.aggregator.aggregate(fl_ctx) - self.log_info(fl_ctx, f"Aggregation result: {aggr_result}") - fl_ctx.set_prop(AppConstants.AGGREGATION_RESULT, aggr_result, private=True, sticky=False) - self.fire_event(AppEventType.AFTER_AGGREGATION, fl_ctx) - self.log_info(fl_ctx, "End aggregation.") - - if abort_signal.triggered: - return - - final_weights = aggr_result.get("weights", None) - self.log_info(fl_ctx, f"Finished Mobile Training. 
Final weights: {final_weights}") + self.log_info(fl_ctx, "Finished Mobile Training.") except Exception as e: error_msg = f"Exception in mobile control_flow: {secure_format_exception(e)}" self.log_exception(fl_ctx, error_msg) @@ -128,8 +111,9 @@ def process_train_result(self, client_task: ClientTask, fl_ctx: FLContext) -> No return - accepted = self.aggregator.accept(result, fl_ctx) - accepted_msg = "ACCEPTED" if accepted else "REJECTED" + # accepted_msg = "ACCEPTED" if accepted else "REJECTED" + accepted_msg = "ACCEPTED" self.log_info( - fl_ctx, f"Contribution from {client_name} {accepted_msg} by the aggregator at round {self.current_round}." + fl_ctx, + f"Contribution from {client_name} {accepted_msg} by the aggregator at round {self.current_round}.", ) diff --git a/examples/advanced/edge/jobs/hello_mobile/app/custom/net.py b/examples/advanced/edge/jobs/hello_mobile/app/custom/net.py new file mode 100644 index 0000000000..3c84238e7c --- /dev/null +++ b/examples/advanced/edge/jobs/hello_mobile/app/custom/net.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch.nn as nn +from torch.nn import functional as F + + +# Basic Net for XOR +class Net(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 10) + self.linear2 = nn.Linear(10, 2) + + def forward(self, x): + return self.linear2(F.sigmoid(self.linear(x))) + + +# On device training requires the loss to be embedded in the model (and be the first output). +# We wrap the original model here and add the loss calculation. This will be the model we export. 
+class TrainingNet(nn.Module): + def __init__(self, net): + super().__init__() + self.net = net + self.loss = nn.CrossEntropyLoss() + + def forward(self, input, label): + pred = self.net(input) + return self.loss(pred, label), pred.detach().argmax(dim=1) diff --git a/examples/advanced/edge/jobs/hello_mobile/meta.json b/examples/advanced/edge/jobs/hello_mobile/meta.json index de8bd290d3..41a2b81537 100644 --- a/examples/advanced/edge/jobs/hello_mobile/meta.json +++ b/examples/advanced/edge/jobs/hello_mobile/meta.json @@ -5,5 +5,5 @@ "app": ["@ALL"] }, "min_clients": 2, - "edge_method": "cnn" + "edge_method": "xor" } diff --git a/examples/advanced/edge/jobs/xor_mobile/app/config/config_fed_client.json b/examples/advanced/edge/jobs/xor_mobile/app/config/config_fed_client.json new file mode 100644 index 0000000000..100aba7a1c --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/app/config/config_fed_client.json @@ -0,0 +1,40 @@ +{ + "format_version": 2, + "num_rounds": 3, + "executors": [ + { + "tasks": [ + "train" + ], + "executor": { + "id": "Executor", + "path": "nvflare.app_common.executors.ham.HierarchicalAggregationManager", + "args": { + "learner_id": "learner", + "aggregator_id": "aggregator", + "aggr_timeout": 600, + "min_responses": 2, + "wait_time_after_min_resps_received": 5 + } + } + } + ], + "components": [ + { + "id": "learner", + "path": "nvflare.edge.executors.edge_dispatch_executor.EdgeDispatchExecutor", + "args": { + "wait_time": 60, + "min_devices": 2, + "aggregator_id": "aggregator" + } + }, + { + "id": "aggregator", + "path": "edge_json_accumulator.EdgeJsonAccumulator", + "args": { + "aggr_key": "data" + } + } + ] +} diff --git a/examples/advanced/edge/jobs/xor_mobile/app/config/config_fed_server.json b/examples/advanced/edge/jobs/xor_mobile/app/config/config_fed_server.json new file mode 100644 index 0000000000..f2239e6cea --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/app/config/config_fed_server.json @@ -0,0 +1,13 @@ +{ + 
"format_version": 2, + "num_rounds": 2, + "workflows": [ + { + "id": "edge_executorch_controller", + "path": "edge_executorch_controller.EdgeExecutorchController", + "args": { + "num_rounds": "{num_rounds}" + } + } + ] +} diff --git a/examples/advanced/edge/jobs/xor_mobile/app/custom/edge_executorch_controller.py b/examples/advanced/edge/jobs/xor_mobile/app/custom/edge_executorch_controller.py new file mode 100644 index 0000000000..b5ac75963d --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/app/custom/edge_executorch_controller.py @@ -0,0 +1,182 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import base64 +from typing import Any, Dict + +import torch +from edge_json_accumulator import EdgeJsonAccumulator +from executorch_export import export_model +from model import Net, TrainingNet +from torch import Tensor + +from nvflare.apis.controller_spec import ClientTask, Task +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.app_event_type import AppEventType +from nvflare.security.logging import secure_format_exception + + +class EdgeExecutorchController(Controller): + def __init__( + self, + num_rounds: int, + ): + super().__init__() + self.model = TrainingNet(Net()) + self.num_rounds = num_rounds + self.current_round = None + self.aggregator = None + + def start_controller(self, fl_ctx: FLContext) -> None: + self.log_info(fl_ctx, "Initializing ExecuTorch mobile workflow.") + self.aggregator = EdgeJsonAccumulator(aggr_key="data") + + # initialize global model + fl_ctx.set_prop(AppConstants.START_ROUND, 1, private=True, sticky=True) + fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self.num_rounds, private=True, sticky=False) + + def stop_controller(self, fl_ctx: FLContext): + self.log_info(fl_ctx, "Stopping ExecuTorch mobile workflow.") + + def _tensor_from_json(self, tensor_data: Dict[str, Any], divide_factor: int) -> Dict[str, Tensor]: + """Convert JSON tensor data to PyTorch tensors.""" + grad_dict = {} + for key, value in tensor_data.items(): + tensor = torch.Tensor(value["data"]).reshape(value["sizes"]) + grad_dict[key] = tensor / divide_factor + print("get grad dict:", grad_dict) + return grad_dict + + def _update_model(self, aggregated_grads: Dict[str, Tensor]) -> None: + """Update model weights using aggregated gradients.""" + for key, param in self.model.state_dict().items(): + if key in 
aggregated_grads: + self.model.state_dict()[key] -= aggregated_grads[key] + + def _export_current_model(self) -> bytes: + """Export current model in ExecutorTorch format.""" + print("model is", self.model.state_dict()) + input_tensor = torch.randn(1, 2) + label_tensor = torch.ones(1, dtype=torch.int64) + model_buffer = export_model(self.model, input_tensor, label_tensor).buffer + base64_encoded = base64.b64encode(model_buffer).decode("utf-8") + return base64_encoded + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: + try: + self.log_info(fl_ctx, "Beginning Executorch mobile training phase.") + + fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self.num_rounds, private=True, sticky=False) + self.fire_event(AppEventType.TRAINING_STARTED, fl_ctx) + + for i in range(self.num_rounds): + self.current_round = i + if abort_signal.triggered: + return + + self.log_info(fl_ctx, f"Round {self.current_round} started.") + fl_ctx.set_prop( + AppConstants.CURRENT_ROUND, + self.current_round, + private=True, + sticky=True, + ) + self.fire_event(AppEventType.ROUND_STARTED, fl_ctx) + + # Create task and send global model to clients + encoded_buffer = self._export_current_model() + + # Compose shareable + task_data = Shareable() + task_data["model"] = encoded_buffer + task_data.set_header(AppConstants.CURRENT_ROUND, self.current_round) + task_data.set_header(AppConstants.NUM_ROUNDS, self.num_rounds) + task_data.add_cookie(AppConstants.CONTRIBUTION_ROUND, self.current_round) + + train_task = Task( + name="train", + data=task_data, + result_received_cb=self.process_train_result, + ) + + self.broadcast_and_wait( + task=train_task, + min_responses=2, + wait_time_after_min_received=10, + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + if abort_signal.triggered: + return + + self.log_info(fl_ctx, "Start aggregation.") + self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx) + aggr_result = self.aggregator.aggregate(fl_ctx) + self.log_info(fl_ctx, f"Aggregation 
result: {aggr_result}") + fl_ctx.set_prop( + AppConstants.AGGREGATION_RESULT, + aggr_result, + private=True, + sticky=False, + ) + self.fire_event(AppEventType.AFTER_AGGREGATION, fl_ctx) + self.log_info(fl_ctx, "End aggregation.") + + # reset aggregator + self.aggregator.reset(fl_ctx) + + # Convert aggregated gradients to PyTorch tensors + divide_factor = aggr_result["num_devices"] + aggregated_grads = self._tensor_from_json(aggr_result["result"], divide_factor) + self.log_info(fl_ctx, f"Aggregated gradients as Tensor: {aggregated_grads}") + + # Update model weights using aggregated gradients + self._update_model(aggregated_grads) + + if abort_signal.triggered: + return + + final_weights = self.model.state_dict() + self.log_info(fl_ctx, f"Finished Mobile Training. Final weights: {final_weights}") + except Exception as e: + error_msg = f"Exception in mobile control_flow: {secure_format_exception(e)}" + self.log_exception(fl_ctx, error_msg) + self.system_panic(error_msg, fl_ctx) + + def process_train_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: + result = client_task.result + client_name = client_task.client.name + rc = result.get_return_code() + + # Raise errors if bad peer context or execution exception. + if rc and rc != ReturnCode.OK: + self.system_panic( + f"Result from {client_name} is bad, error code: {rc}. 
" + f"{self.__class__.__name__} exiting at round {self.current_round}.", + fl_ctx=fl_ctx, + ) + + return + + accepted = self.aggregator.accept(result, fl_ctx) + accepted_msg = "ACCEPTED" if accepted else "REJECTED" + self.log_info( + fl_ctx, + f"Contribution from {client_name} {accepted_msg} by the aggregator at round {self.current_round}.", + ) diff --git a/examples/advanced/edge/jobs/xor_mobile/app/custom/edge_json_accumulator.py b/examples/advanced/edge/jobs/xor_mobile/app/custom/edge_json_accumulator.py new file mode 100644 index 0000000000..af26cd44fc --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/app/custom/edge_json_accumulator.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.aggregator import Aggregator + + +class EdgeJsonAccumulator(Aggregator): + def __init__(self, aggr_key: str): + Aggregator.__init__(self) + self.weights = None + self.num_devices = 0 + self.aggr_key = aggr_key + + def _aggregate(self, weight_base, weight_to_add): + # aggregates the dict on items with the aggregation key + # iteratively find the key and add the values + for key, sub_object in weight_base.items(): + if isinstance(sub_object, dict): + sub_to_add = weight_to_add.get(key) + self._aggregate(sub_object, sub_to_add) + if self.aggr_key in weight_base: + weight_base[self.aggr_key] = np.add(weight_base[self.aggr_key], weight_to_add[self.aggr_key]) + return weight_base + + def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool: + self.log_info(fl_ctx, f"Accepting: {shareable}") + weight_to_add = shareable.get("result") + if weight_to_add is None: + return True + + # bottom level does not have num_devices + # in which case num_devices_to_add is 1 + num_devices_to_add = shareable.get("num_devices") + if num_devices_to_add is None: + num_devices_to_add = 1 + self.num_devices += num_devices_to_add + + # add new weights to the existing weights + if self.weights is None: + self.weights = weight_to_add + else: + self.weights = self._aggregate(self.weights, weight_to_add) + + return True + + def reset(self, fl_ctx: FLContext): + self.weights = None + self.num_devices = 0 + + def aggregate(self, fl_ctx: FLContext) -> Shareable: + return Shareable({"result": self.weights, "num_devices": self.num_devices}) diff --git a/examples/advanced/edge/jobs/xor_mobile/app/custom/executorch_export.py b/examples/advanced/edge/jobs/xor_mobile/app/custom/executorch_export.py new file mode 100644 index 0000000000..25078c2120 --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/app/custom/executorch_export.py @@ -0,0 +1,16 
@@ +from executorch.exir import to_edge +from torch.export import export +from torch.export.experimental import _export_forward_backward + + +def export_model(net, input_tensor_example, label_tensor_example): + # Captures the forward graph. The graph will look similar to the model definition now. + # Will move to export_for_training soon which is the api planned to be supported in the long term. + ep = export(net, (input_tensor_example, label_tensor_example), strict=True) + # Captures the backward graph. The exported_program now contains the joint forward and backward graph. + ep = _export_forward_backward(ep) + # Lower the graph to edge dialect. + ep = to_edge(ep) + # Lower the graph to executorch. + ep = ep.to_executorch() + return ep diff --git a/examples/advanced/edge/jobs/xor_mobile/app/custom/model.py b/examples/advanced/edge/jobs/xor_mobile/app/custom/model.py new file mode 100644 index 0000000000..3c84238e7c --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/app/custom/model.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch.nn as nn +from torch.nn import functional as F + + +# Basic Net for XOR +class Net(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 10) + self.linear2 = nn.Linear(10, 2) + + def forward(self, x): + return self.linear2(F.sigmoid(self.linear(x))) + + +# On device training requires the loss to be embedded in the model (and be the first output). +# We wrap the original model here and add the loss calculation. This will be the model we export. 
+class TrainingNet(nn.Module): + def __init__(self, net): + super().__init__() + self.net = net + self.loss = nn.CrossEntropyLoss() + + def forward(self, input, label): + pred = self.net(input) + return self.loss(pred, label), pred.detach().argmax(dim=1) diff --git a/examples/advanced/edge/jobs/xor_mobile/meta.json b/examples/advanced/edge/jobs/xor_mobile/meta.json new file mode 100644 index 0000000000..ca3918da35 --- /dev/null +++ b/examples/advanced/edge/jobs/xor_mobile/meta.json @@ -0,0 +1,9 @@ +{ + "name": "xor_mobile", + "resource_spec": {}, + "deploy_map": { + "app": ["@ALL"] + }, + "min_clients": 2, + "edge_method": "xor" +} diff --git a/nvflare/edge/emulator/device_emulator.py b/nvflare/edge/emulator/device_emulator.py index 66d877cdbb..96ac363de1 100644 --- a/nvflare/edge/emulator/device_emulator.py +++ b/nvflare/edge/emulator/device_emulator.py @@ -28,8 +28,14 @@ class DeviceEmulator: - def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo, - capabilities: Optional[dict], processor: DeviceTaskProcessor): + def __init__( + self, + endpoint: str, + device_info: DeviceInfo, + user_info: UserInfo, + capabilities: Optional[dict], + processor: DeviceTaskProcessor, + ): self.device_info = device_info self.device_id = device_info.device_id self.user_info = user_info diff --git a/nvflare/edge/emulator/device_task_processor.py b/nvflare/edge/emulator/device_task_processor.py index 6de070fd04..6a41aacbab 100644 --- a/nvflare/edge/emulator/device_task_processor.py +++ b/nvflare/edge/emulator/device_task_processor.py @@ -64,4 +64,3 @@ def process_task(self, task: TaskResponse) -> dict: The result as a dict """ pass - diff --git a/nvflare/edge/emulator/eta_api.py b/nvflare/edge/emulator/eta_api.py index 0f9937c8c4..322344abf1 100644 --- a/nvflare/edge/emulator/eta_api.py +++ b/nvflare/edge/emulator/eta_api.py @@ -58,12 +58,39 @@ def get_task(self, job: JobResponse) -> TaskResponse: params = { "job_id": job.job_id, } - response = 
requests.get(url, params=params, headers=self.common_headers) - code = response.status_code - if code == 200: - return TaskResponse(**response.json()) - raise ApiError(code, "ERROR", f"API Call failed with status code {code}", response.json()) + try: + response = requests.get(url, params=params, headers=self.common_headers) + code = response.status_code + + # Debug logging + print(f"Request URL: {url}") + print(f"Request params: {params}") + print(f"Response status: {code}") + print(f"Response content: {response.content}") + + if code == 200: + try: + json_data = response.json() + return TaskResponse(**json_data) + except ValueError as e: + raise ApiError( + code, + "ERROR", + f"Invalid JSON response: {str(e)}", + {"raw_content": response.content.decode("utf-8", errors="ignore")}, + ) + + # Handle non-200 responses + try: + error_data = response.json() if response.content else {} + except ValueError: + error_data = {"raw_content": response.content.decode("utf-8", errors="ignore")} + + raise ApiError(code, "ERROR", f"API Call failed with status code {code}", error_data) + + except requests.exceptions.RequestException as e: + raise ApiError(500, "ERROR", f"Request failed: {str(e)}", {"error": str(e)}) def report_result(self, task: TaskResponse, result: dict) -> ResultResponse: url = urljoin(self.endpoint, "result") diff --git a/nvflare/edge/emulator/run_emulator.py b/nvflare/edge/emulator/run_emulator.py index e3fa4d07e9..f5f54739b7 100644 --- a/nvflare/edge/emulator/run_emulator.py +++ b/nvflare/edge/emulator/run_emulator.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import traceback import logging import sys from concurrent.futures import ThreadPoolExecutor, wait from nvflare.edge.emulator.device_emulator import DeviceEmulator from nvflare.edge.emulator.device_task_processor import DeviceTaskProcessor -from nvflare.edge.emulator.sample_task_processor import SampleTaskProcessor +from nvflare.edge.emulator.xor_task_processor import XorTaskProcessor from nvflare.edge.web.models.device_info import DeviceInfo from nvflare.edge.web.models.user_info import UserInfo @@ -27,16 +29,13 @@ def device_run(endpoint_url: str, device_info: DeviceInfo, user_info: UserInfo, processor: DeviceTaskProcessor): device_id = device_info.device_id try: - capabilities = { - "methods": ["xgboost", "cnn"], - "cpu": 16, - "gpu": 1024 - } + capabilities = {"methods": ["xor"], "cpu": 16, "gpu": 1024} emulator = DeviceEmulator(endpoint_url, device_info, user_info, capabilities, processor) emulator.run() log.info(f"Emulator run for device {device_id} ended") except Exception as ex: + traceback.print_exc() log.error(f"Device {device_id} failed to run: {ex}") @@ -46,7 +45,7 @@ def run_emulator(endpoint_url: str, num: int): for i in range(num): device_info = DeviceInfo(f"device-{i}", "flare_mobile", "1.0") user_info = UserInfo("demo_id", "demo_user") - processor = SampleTaskProcessor(device_info, user_info) + processor = XorTaskProcessor(device_info, user_info) f = thread_pool.submit(device_run, endpoint_url, device_info, user_info, processor) futures.append(f) @@ -60,7 +59,7 @@ def run_emulator(endpoint_url: str, num: int): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()] + handlers=[logging.StreamHandler()], ) n = len(sys.argv) diff --git a/nvflare/edge/emulator/sample_task_processor.py b/nvflare/edge/emulator/sample_task_processor.py deleted file mode 100644 index 6b66c8767f..0000000000 --- a/nvflare/edge/emulator/sample_task_processor.py +++ /dev/null @@ -1,57 +0,0 
@@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from nvflare.edge.emulator.device_task_processor import DeviceTaskProcessor -from nvflare.edge.web.models.device_info import DeviceInfo -from nvflare.edge.web.models.job_response import JobResponse -from nvflare.edge.web.models.task_response import TaskResponse -from nvflare.edge.web.models.user_info import UserInfo - -log = logging.getLogger(__name__) - - -class SampleTaskProcessor(DeviceTaskProcessor): - def __init__(self, device_info: DeviceInfo, user_info: UserInfo): - super().__init__(device_info, user_info) - self.job_id = None - self.job_name = None - - def setup(self, job: JobResponse) -> None: - self.job_id = job.job_id - self.job_name = job.job_name - # job.job_data contains data needed to set up the training - - def shutdown(self) -> None: - pass - - def process_task(self, task: TaskResponse) -> dict: - log.info(f"Processing task {task.task_name}") - - result = None - if task.task_name == "train": - weights = task.task_data["weights"] - if weights: - w = [x * 2.0 for x in weights] - else: - w = [0, 0, 0, 0] - result = {"weights": w} - elif task.task_name == "validate": - result = { - "accuracy": [0.01, 0.02, 0.03, 0.04] - } - else: - log.error(f"Received unknown task: {task.task_name}") - - return result diff --git a/nvflare/edge/emulator/train_xor b/nvflare/edge/emulator/train_xor new file mode 100755 index 
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import logging
import os
import shutil
import subprocess

import torch  # NOTE(review): nothing in this module uses torch; kept because other tooling may rely on the import

from nvflare.edge.emulator.device_task_processor import DeviceTaskProcessor
from nvflare.edge.web.models.device_info import DeviceInfo
from nvflare.edge.web.models.job_response import JobResponse
from nvflare.edge.web.models.task_response import TaskResponse
from nvflare.edge.web.models.user_info import UserInfo

log = logging.getLogger(__name__)

# File name of the training executable that is copied into each device's work directory.
SOURCE_BINARY = "train_xor"


def save_to_pte(model_string: str, filename: str):
    """Decode a base64-encoded model and write the raw bytes to ``filename``.

    Args:
        model_string: base64 text of the model file.
        filename: destination path; it is created or overwritten.

    Raises:
        binascii.Error: if ``model_string`` is not valid base64.
        OSError: if the file cannot be written.
    """
    binary_data = base64.b64decode(model_string)
    with open(filename, "wb") as f:
        f.write(binary_data)


def run_training_with_timeout(train_program: str, model_path: str, result_path: str, timeout_seconds: int = 300) -> int:
    """Run the external training program and wait for it to finish.

    The program is invoked as
    ``<train_program> --model_path <model_path> --output_path <result_path>``.

    Args:
        train_program: path of the executable to run.
        model_path: value passed as ``--model_path``.
        result_path: value passed as ``--output_path``.
        timeout_seconds: maximum time to wait for the program.

    Returns:
        The process return code, which is always 0 because any other code raises.

    Raises:
        subprocess.TimeoutExpired: the program did not finish within ``timeout_seconds``.
        subprocess.CalledProcessError: the program exited with a non-zero code; the
            exception carries the real command line plus the captured stdout/stderr.
        OSError: the program could not be started (missing file, not executable, ...).
    """
    command = [train_program, "--model_path", model_path, "--output_path", result_path]

    try:
        # subprocess.run kills AND reaps the child when the timeout expires, so no
        # zombie process or open pipe is left behind.
        completed = subprocess.run(command, capture_output=True, text=True, timeout=timeout_seconds)
    except subprocess.TimeoutExpired:
        log.error("Training timed out after %s seconds", timeout_seconds)
        raise
    except Exception as e:
        log.error("Error during training: %s", e)
        raise

    if completed.returncode != 0:
        log.error("Error output: %s", completed.stderr)
        # Report the command that was actually executed, not a hard-coded name.
        raise subprocess.CalledProcessError(completed.returncode, command, completed.stdout, completed.stderr)

    log.info("Output: %s", completed.stdout)
    return completed.returncode


def read_training_result(result_path: str = "training_result.json"):
    """Load the JSON document written by the training program.

    Args:
        result_path: path of the JSON result file.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: ``result_path`` does not exist.
        json.JSONDecodeError: the file is not valid JSON.
    """
    try:
        with open(result_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        log.error("Could not find file: %s", result_path)
        raise
    except json.JSONDecodeError:
        log.error("Error parsing JSON file: %s", result_path)
        raise
    except Exception as e:
        log.error("Unexpected error: %s", e)
        raise


def _resolve_source_binary() -> str:
    """Return the path of the training executable to copy.

    A ``train_xor`` in the current working directory wins (the original lookup);
    otherwise the copy located next to this module is used.
    """
    if os.path.exists(SOURCE_BINARY):
        return SOURCE_BINARY
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), SOURCE_BINARY)


class XorTaskProcessor(DeviceTaskProcessor):
    """Device task processor that handles "train" tasks by running the ``train_xor`` program.

    Every device gets its own ``<device_id>_output`` directory (created under the
    current working directory) holding the received model, the training result
    and a private copy of the training executable.
    """

    def __init__(self, device_info: DeviceInfo, user_info: UserInfo):
        """Create the per-device work directory and stage the training executable.

        Args:
            device_info: description of the emulated device; its ``device_id`` names
                the work directory.
            user_info: user information, forwarded to the base class.
        """
        super().__init__(device_info, user_info)
        self.job_id = None
        self.job_name = None
        self.device_info = device_info

        device_io_dir = f"{device_info.device_id}_output"
        os.makedirs(device_io_dir, exist_ok=True)
        # Absolute paths, so a later change of working directory cannot break them.
        self.model_path = os.path.abspath(os.path.join(device_io_dir, "xor.pte"))
        self.result_path = os.path.abspath(os.path.join(device_io_dir, "training_result.json"))
        self.train_binary = os.path.abspath(os.path.join(device_io_dir, "train.xor"))
        self._setup_train_program()

    def _setup_train_program(self):
        """Copy the training executable into the device directory (once) and make it executable."""
        if not os.path.exists(self.train_binary):
            shutil.copy2(_resolve_source_binary(), self.train_binary)
        # Make it executable
        os.chmod(self.train_binary, 0o755)

    def setup(self, job: JobResponse) -> None:
        """Remember the id and name of the job this device is working on."""
        self.job_id = job.job_id
        self.job_name = job.job_name

    def shutdown(self) -> None:
        """Nothing to release."""
        pass

    def process_task(self, task: TaskResponse) -> dict:
        """Run local training for a "train" task and return what the trainer produced.

        Args:
            task: task received from the server. For "train" tasks,
                ``task.task_data["task_data"]`` must hold the base64-encoded model.

        Returns:
            ``{"result": <content of the training result file>}`` for a "train" task,
            ``None`` for any other task name.

        Raises:
            FileNotFoundError: no result file was produced (e.g. training failed).
            json.JSONDecodeError: the result file is not valid JSON.
        """
        # Log only the task name; the task object presumably carries the full
        # base64 model payload - TODO confirm what str(task) renders.
        log.info("Processing task %s", task.task_name)

        if task.task_name != "train":
            log.error("Received unknown task: %s", task.task_name)
            return None

        # save received pte
        save_to_pte(task.task_data["task_data"], self.model_path)

        # Drop the result of any previous round first: a failed training run must
        # not be reported with stale data left over in the result file.
        try:
            os.remove(self.result_path)
        except FileNotFoundError:
            pass

        # Training failures are logged rather than raised here; the result read
        # below then fails with FileNotFoundError if no result was written.
        try:
            run_training_with_timeout(self.train_binary, self.model_path, self.result_path, timeout_seconds=600)
            log.info("Training completed successfully")
        except subprocess.TimeoutExpired:
            log.error("Training took too long and was terminated")
        except subprocess.CalledProcessError as e:
            log.error("Training failed with return code %s", e.returncode)
        except Exception as e:
            log.error("Training Unexpected error: %s", e)

        try:
            diff_dict = read_training_result(self.result_path)
        except Exception as e:
            log.error("Failed to read results: %s", e)
            raise

        return {
            "result": diff_dict,
        }
- return { - "weights": task_data.get("weights"), - "task_id": self.task_id - } + return {"task_data": task_data["model"], "task_id": self.task_id} def convert_result(self, result: dict) -> Shareable: """Convert result from device to shareable""" @@ -78,8 +75,7 @@ def handle_task_request(self, request: TaskRequest, fl_ctx: FLContext) -> TaskRe # This device already processed current task last_task_id = self.devices.get(device_id, None) if self.task_id == last_task_id: - return TaskResponse("RETRY", job_id, 30, - message=f"Task {self.task_id} is already processed by this device") + return TaskResponse("RETRY", job_id, 30, message=f"Task {self.task_id} is already processed by this device") task_done = self.current_task.get("task_done") task_data = self.convert_task(self.current_task) diff --git a/nvflare/edge/web/handlers/edge_task_handler.py b/nvflare/edge/web/handlers/edge_task_handler.py index 78898b38dd..e066fc5837 100644 --- a/nvflare/edge/web/handlers/edge_task_handler.py +++ b/nvflare/edge/web/handlers/edge_task_handler.py @@ -39,4 +39,3 @@ def handle_task(self, task_request: TaskRequest) -> TaskResponse: @abstractmethod def handle_result(self, result_report: ResultReport) -> ResultResponse: pass - diff --git a/nvflare/edge/web/handlers/lcp_task_handler.py b/nvflare/edge/web/handlers/lcp_task_handler.py index a4f9c72124..d4badd90e4 100644 --- a/nvflare/edge/web/handlers/lcp_task_handler.py +++ b/nvflare/edge/web/handlers/lcp_task_handler.py @@ -66,7 +66,7 @@ def handle_task(self, task_request: TaskRequest) -> TaskResponse: data = reply.get(EdgeProtoKey.DATA) response = data.get("response") elif status == Status.NO_JOB: - self.logger_error(f"Job {task_request.job_id} is done") + self.logger.error(f"Job {task_request.job_id} is done") response = TaskResponse("NO_JOB", retry_wait=30, job_id=task_request.job_id) else: self.logger.error(f"Task request for {task_request.job_id} failed with status {status}") @@ -96,4 +96,3 @@ def _handle_task_request(self, request: 
Any) -> dict: reply = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) assert isinstance(reply, dict) return reply - \ No newline at end of file diff --git a/nvflare/edge/web/handlers/sample_task_data.py b/nvflare/edge/web/handlers/sample_task_data.py index b3503c3df1..0bfecb2a95 100644 --- a/nvflare/edge/web/handlers/sample_task_data.py +++ b/nvflare/edge/web/handlers/sample_task_data.py @@ -25,8 +25,14 @@ from nvflare.edge.web.models.user_info import UserInfo jobs = [ - JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "demo_job", "ExecuTorch", - job_data={"executorch_parameters": [1.2, 3.4, 5.6]}), + JobResponse( + "OK", + str(uuid.uuid4()), + str(uuid.uuid4()), + "demo_job", + "ExecuTorch", + job_data={"executorch_parameters": [1.2, 3.4, 5.6]}, + ), JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "xgb_job", "xgboost"), JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "core_job", "coreML"), JobResponse("RETRY", str(uuid.uuid4()), retry_wait=60), @@ -64,13 +70,7 @@ def handle_task_request(device_info: DeviceInfo, user_info: UserInfo, task_reque task_name = state["next_task"] task_id = state["task_id"] - reply = TaskResponse( - "OK", - session_id, - None, - task_id, - task_name, - {}) + reply = TaskResponse("OK", session_id, None, task_id, task_name, {}) return reply @@ -90,14 +90,9 @@ def handle_result_report(device_info: DeviceInfo, user_info: UserInfo, result_re status = "OK" task_id = state["task_id"] - state["next_task"] = demo_tasks[index+1] + state["next_task"] = demo_tasks[index + 1] state["task_id"] = str(uuid.uuid4()) - reply = ResultResponse( - status, - None, - session_id, - task_id, - result_report.task_name) + reply = ResultResponse(status, None, session_id, task_id, result_report.task_name) return reply diff --git a/nvflare/edge/web/models/base_model.py b/nvflare/edge/web/models/base_model.py index 2d3dabe5a9..a00473ec13 100644 --- a/nvflare/edge/web/models/base_model.py +++ b/nvflare/edge/web/models/base_model.py @@ -39,4 +39,3 @@ 
def get_device_id(self) -> Optional[str]: return None return device_info.get("device_id") - diff --git a/nvflare/edge/web/models/device_info.py b/nvflare/edge/web/models/device_info.py index cdb9fa261a..24719d0a0e 100644 --- a/nvflare/edge/web/models/device_info.py +++ b/nvflare/edge/web/models/device_info.py @@ -4,8 +4,15 @@ class DeviceInfo(BaseModel): """Device information""" - def __init__(self, device_id: str, app_name: str = None, app_version: str = None, - platform: str = None, platform_version: str = None, **kwargs): + def __init__( + self, + device_id: str, + app_name: str = None, + app_version: str = None, + platform: str = None, + platform_version: str = None, + **kwargs, + ): super().__init__() self.device_id = device_id self.app_name = app_name diff --git a/nvflare/edge/web/models/task_request.py b/nvflare/edge/web/models/task_request.py index 5f7efb0443..054dd1afaf 100644 --- a/nvflare/edge/web/models/task_request.py +++ b/nvflare/edge/web/models/task_request.py @@ -4,13 +4,7 @@ class TaskRequest(BaseModel): - def __init__( - self, - device_info: DeviceInfo, - user_info: UserInfo, - job_id: str, - **kwargs - ): + def __init__(self, device_info: DeviceInfo, user_info: UserInfo, job_id: str, **kwargs): super().__init__() self.device_info = device_info self.user_info = user_info diff --git a/nvflare/edge/web/models/user_info.py b/nvflare/edge/web/models/user_info.py index 3fa4524f71..573e13ff48 100644 --- a/nvflare/edge/web/models/user_info.py +++ b/nvflare/edge/web/models/user_info.py @@ -3,8 +3,15 @@ class UserInfo(BaseModel): - def __init__(self, user_id: str = None, user_name: str = None, access_token: str = None, auth_token: str = None, - auth_session: str = None, **kwargs): + def __init__( + self, + user_id: str = None, + user_name: str = None, + access_token: str = None, + auth_token: str = None, + auth_session: str = None, + **kwargs, + ): super().__init__() self.user_id = user_id self.user_name = user_name diff --git 
a/nvflare/edge/web/routing_proxy.py b/nvflare/edge/web/routing_proxy.py index 6e4213981b..6ef0eb0b09 100644 --- a/nvflare/edge/web/routing_proxy.py +++ b/nvflare/edge/web/routing_proxy.py @@ -66,7 +66,7 @@ def handle_api_error(error: ApiError): mapper = LcpMapper() -@app.route('/', methods=['GET', 'POST']) +@app.route("/", methods=["GET", "POST"]) def routing_proxy(path): device_id = request.headers.get("X-Flare-Device-ID") @@ -83,7 +83,7 @@ def routing_proxy(path): try: # Prepare headers (remove 'Host' to avoid conflicts) - headers = {key: value for key, value in request.headers if key.lower() != 'host'} + headers = {key: value for key, value in request.headers if key.lower() != "host"} # Get data from the original request data = request.get_data() @@ -96,11 +96,11 @@ def routing_proxy(path): headers=headers, data=data, cookies=request.cookies, - allow_redirects=False # Do not follow redirects + allow_redirects=False, # Do not follow redirects ) # Exclude specific headers from the target response - excluded_headers = ['server', 'date', 'content-encoding', 'content-length', 'transfer-encoding', 'connection'] + excluded_headers = ["server", "date", "content-encoding", "content-length", "transfer-encoding", "connection"] headers = {name: value for name, value in resp.headers.items() if name.lower() not in excluded_headers} headers["Via"] = "edge-proxy" @@ -112,12 +112,12 @@ def routing_proxy(path): raise ApiError(500, "PROXY_ERROR", f"Proxy request failed: {str(ex)}", ex) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()] + handlers=[logging.StreamHandler()], ) if len(sys.argv) != 3: @@ -128,4 +128,4 @@ def routing_proxy(path): port = int(sys.argv[1]) app.json = FilteredJSONProvider(app) - app.run(host='0.0.0.0', port=port, debug=False) + app.run(host="0.0.0.0", port=port, debug=False)