Skip to content

Commit

Permalink
Mobile Dev: Add XOR E2E
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Mar 3, 2025
1 parent 88a4a6b commit 7e50d19
Show file tree
Hide file tree
Showing 29 changed files with 667 additions and 160 deletions.
15 changes: 15 additions & 0 deletions examples/advanced/edge/edge__p_resources.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"format_version": 2,
"components": [
{
"id": "web_agent",
"path": "nvflare.edge.widgets.web_agent.WebAgent",
"args": {}
},
{
"id": "etd",
"path": "nvflare.edge.widgets.etd.EdgeTaskDispatcher",
"args": {}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"args": {
"learner_id": "learner",
"aggregator_id": "aggregator",
"aggr_timeout": 60,
"aggr_timeout": 600,
"min_responses": 2,
"wait_time_after_min_resps_received": 5
}
Expand All @@ -24,7 +24,7 @@
"id": "learner",
"path": "nvflare.edge.executors.edge_dispatch_executor.EdgeDispatchExecutor",
"args": {
"wait_time": 60,
"wait_time": 600,
"min_devices": 2,
"aggregator_id": "aggregator"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
{
"format_version": 2,
"num_rounds": 2,
"num_rounds": 1,
"workflows": [
{
"id": "edge_controller",
"path": "edge_controller.SimpleEdgeController",
"args": {
"num_rounds": "{num_rounds}",
"initial_weights": [
1.0, 2.0, 3.0, 4.0
]
"model_string_file": "/Users/yuantingh/NVFlare/examples/advanced/edge/xor_model.txt"
}
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from nvflare.apis.controller_spec import Task, ClientTask

from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
from nvflare.edge.aggregators.edge_result_accumulator import EdgeResultAccumulator
from nvflare.security.logging import secure_format_exception


class SimpleEdgeController(Controller):

def __init__(
self,
num_rounds: int,
initial_weights: Any
):
def __init__(self, num_rounds: int, model_string_file: str):
super().__init__()
self.num_rounds = num_rounds
self.current_round = None
self.initial_weights = initial_weights
self.aggregator = None
with open(model_string_file, "r") as f:
model_string = f.read()
self.model_string = model_string

def start_controller(self, fl_ctx: FLContext) -> None:
self.log_info(fl_ctx, "Initializing Simple mobile workflow.")
self.aggregator = EdgeResultAccumulator()

# initialize global model
fl_ctx.set_prop(AppConstants.START_ROUND, 1, private=True, sticky=True)
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self.num_rounds, private=True, sticky=False)
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.initial_weights, private=True, sticky=True)
self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

def stop_controller(self, fl_ctx: FLContext):
self.log_info(fl_ctx, "Stopping Simple mobile workflow.")
Expand All @@ -66,14 +58,17 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
return

self.log_info(fl_ctx, f"Round {self.current_round} started.")
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.initial_weights, private=True, sticky=True)
fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self.current_round, private=True, sticky=True)
fl_ctx.set_prop(
AppConstants.CURRENT_ROUND,
self.current_round,
private=True,
sticky=True,
)
self.fire_event(AppEventType.ROUND_STARTED, fl_ctx)

# Create train_task
task_data = Shareable()
task_data["weights"] = self.initial_weights
task_data["task_done"] = self.current_round >= (self.num_rounds - 1)
task_data["model"] = self.model_string
task_data.set_header(AppConstants.CURRENT_ROUND, self.current_round)
task_data.set_header(AppConstants.NUM_ROUNDS, self.num_rounds)
task_data.add_cookie(AppConstants.CONTRIBUTION_ROUND, self.current_round)
Expand All @@ -95,19 +90,7 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None:
if abort_signal.triggered:
return

self.log_info(fl_ctx, "Start aggregation.")
self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx)
aggr_result = self.aggregator.aggregate(fl_ctx)
self.log_info(fl_ctx, f"Aggregation result: {aggr_result}")
fl_ctx.set_prop(AppConstants.AGGREGATION_RESULT, aggr_result, private=True, sticky=False)
self.fire_event(AppEventType.AFTER_AGGREGATION, fl_ctx)
self.log_info(fl_ctx, "End aggregation.")

if abort_signal.triggered:
return

final_weights = aggr_result.get("weights", None)
self.log_info(fl_ctx, f"Finished Mobile Training. Final weights: {final_weights}")
self.log_info(fl_ctx, "Finished Mobile Training.")
except Exception as e:
error_msg = f"Exception in mobile control_flow: {secure_format_exception(e)}"
self.log_exception(fl_ctx, error_msg)
Expand All @@ -128,8 +111,9 @@ def process_train_result(self, client_task: ClientTask, fl_ctx: FLContext) -> No

return

accepted = self.aggregator.accept(result, fl_ctx)
accepted_msg = "ACCEPTED" if accepted else "REJECTED"
# accepted_msg = "ACCEPTED" if accepted else "REJECTED"
accepted_msg = "ACCEPTED"
self.log_info(
fl_ctx, f"Contribution from {client_name} {accepted_msg} by the aggregator at round {self.current_round}."
fl_ctx,
f"Contribution from {client_name} {accepted_msg} by the aggregator at round {self.current_round}.",
)
34 changes: 34 additions & 0 deletions examples/advanced/edge/jobs/hello_mobile/app/custom/net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch.nn as nn
from torch.nn import functional as F


# Basic Net for XOR
class Net(nn.Module):
    """Small two-layer perceptron with a sigmoid hidden activation (basic net for XOR).

    The defaults reproduce the original fixed architecture: 2 inputs, a hidden
    layer of 10 units, and 2 outputs.

    Args:
        in_features: Size of each input sample.
        hidden_features: Number of units in the hidden layer.
        out_features: Number of values produced per sample.
    """

    def __init__(self, in_features: int = 2, hidden_features: int = 10, out_features: int = 2):
        super().__init__()
        # Keep the attribute names `linear` / `linear2`: they determine the
        # parameter names in the module's state_dict.
        self.linear = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply linear -> sigmoid -> linear2 to ``x`` and return the result.

        No activation is applied after the second layer.
        """
        hidden = F.sigmoid(self.linear(x))
        return self.linear2(hidden)


# On-device training requires the loss to be embedded in the model (and to be its first output).
# We wrap the original model here and add the loss calculation; the wrapped model is the one we export.
class TrainingNet(nn.Module):
    """Wrap a model so that its forward pass also computes the training loss.

    Args:
        net: The module to wrap. Its output is passed to ``nn.CrossEntropyLoss``
            as the prediction argument.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    # NOTE: the parameter name `input` shadows the builtin, but it is part of
    # the call signature and is kept so keyword callers continue to work.
    def forward(self, input, label):
        """Run the wrapped model and return ``(loss, predictions)``.

        Args:
            input: Batch passed unchanged to the wrapped model.
            label: Targets passed to the cross-entropy loss.

        Returns:
            A tuple whose first element is the cross-entropy loss and whose
            second element is the argmax of the model output along dim 1,
            taken on a detached tensor so it carries no gradient.
        """
        pred = self.net(input)
        loss_value = self.loss(pred, label)
        predicted = pred.detach().argmax(dim=1)
        # The loss is returned first, ahead of the predictions.
        return loss_value, predicted
2 changes: 1 addition & 1 deletion examples/advanced/edge/jobs/hello_mobile/meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
"app": ["@ALL"]
},
"min_clients": 2,
"edge_method": "cnn"
"edge_method": "xor"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"format_version": 2,
"num_rounds": 3,
"executors": [
{
"tasks": [
"train"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.ham.HierarchicalAggregationManager",
"args": {
"learner_id": "learner",
"aggregator_id": "aggregator",
"aggr_timeout": 600,
"min_responses": 2,
"wait_time_after_min_resps_received": 5
}
}
}
],
"components": [
{
"id": "learner",
"path": "nvflare.edge.executors.edge_dispatch_executor.EdgeDispatchExecutor",
"args": {
"wait_time": 60,
"min_devices": 2,
"aggregator_id": "aggregator"
}
},
{
"id": "aggregator",
"path": "edge_json_accumulator.EdgeJsonAccumulator",
"args": {
"aggr_key": "data"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"format_version": 2,
"num_rounds": 2,
"workflows": [
{
"id": "edge_executorch_controller",
"path": "edge_executorch_controller.EdgeExecutorchController",
"args": {
"num_rounds": "{num_rounds}"
}
}
]
}
Loading

0 comments on commit 7e50d19

Please sign in to comment.