From e8de7624b379a1ae2f045ef2b82dd990aee43d43 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Mon, 3 Mar 2025 12:05:51 -0500 Subject: [PATCH] Changed _remove_job to use job_id to get around the bug that job_meta is missing --- .../aggregators/edge_result_accumulator.py | 4 +- nvflare/edge/emulator/device_emulator.py | 10 ++++- .../edge/emulator/device_task_processor.py | 1 - nvflare/edge/emulator/run_emulator.py | 8 +--- .../edge/emulator/sample_task_processor.py | 4 +- .../edge/executors/edge_dispatch_executor.py | 19 ++++----- .../edge/web/handlers/edge_task_handler.py | 1 - nvflare/edge/web/handlers/lcp_task_handler.py | 1 - nvflare/edge/web/handlers/sample_task_data.py | 27 +++++-------- nvflare/edge/web/models/base_model.py | 1 - nvflare/edge/web/models/device_info.py | 11 ++++- nvflare/edge/web/models/task_request.py | 8 +--- nvflare/edge/web/models/user_info.py | 11 ++++- nvflare/edge/web/routing_proxy.py | 16 ++++---- nvflare/edge/web/web_server.py | 1 + nvflare/edge/widgets/etd.py | 40 +++++++------------ nvflare/edge/widgets/etg.py | 4 +- nvflare/fuel/utils/fobs/decomposer.py | 2 +- 18 files changed, 79 insertions(+), 90 deletions(-) diff --git a/nvflare/edge/aggregators/edge_result_accumulator.py b/nvflare/edge/aggregators/edge_result_accumulator.py index b0d76fc4ac..e16220593b 100644 --- a/nvflare/edge/aggregators/edge_result_accumulator.py +++ b/nvflare/edge/aggregators/edge_result_accumulator.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable from nvflare.app_common.abstract.aggregator import Aggregator -import numpy as np - class EdgeResultAccumulator(Aggregator): def __init__(self): diff --git a/nvflare/edge/emulator/device_emulator.py b/nvflare/edge/emulator/device_emulator.py index 66d877cdbb..96ac363de1 100644 --- a/nvflare/edge/emulator/device_emulator.py +++ b/nvflare/edge/emulator/device_emulator.py @@ -28,8 +28,14 @@ class DeviceEmulator: - def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo, - capabilities: Optional[dict], processor: DeviceTaskProcessor): + def __init__( + self, + endpoint: str, + device_info: DeviceInfo, + user_info: UserInfo, + capabilities: Optional[dict], + processor: DeviceTaskProcessor, + ): self.device_info = device_info self.device_id = device_info.device_id self.user_info = user_info diff --git a/nvflare/edge/emulator/device_task_processor.py b/nvflare/edge/emulator/device_task_processor.py index 6de070fd04..6a41aacbab 100644 --- a/nvflare/edge/emulator/device_task_processor.py +++ b/nvflare/edge/emulator/device_task_processor.py @@ -64,4 +64,3 @@ def process_task(self, task: TaskResponse) -> dict: The result as a dict """ pass - diff --git a/nvflare/edge/emulator/run_emulator.py b/nvflare/edge/emulator/run_emulator.py index e3fa4d07e9..dea39f514b 100644 --- a/nvflare/edge/emulator/run_emulator.py +++ b/nvflare/edge/emulator/run_emulator.py @@ -27,11 +27,7 @@ def device_run(endpoint_url: str, device_info: DeviceInfo, user_info: UserInfo, processor: DeviceTaskProcessor): device_id = device_info.device_id try: - capabilities = { - "methods": ["xgboost", "cnn"], - "cpu": 16, - "gpu": 1024 - } + capabilities = {"methods": ["xgboost", "cnn"], "cpu": 16, "gpu": 1024} emulator = DeviceEmulator(endpoint_url, device_info, user_info, capabilities, processor) emulator.run() @@ -60,7 +56,7 @@ def run_emulator(endpoint_url: str, num: int): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()] + handlers=[logging.StreamHandler()], ) n = len(sys.argv) diff --git a/nvflare/edge/emulator/sample_task_processor.py b/nvflare/edge/emulator/sample_task_processor.py index 6b66c8767f..ae8cce5d49 100644 --- a/nvflare/edge/emulator/sample_task_processor.py +++ b/nvflare/edge/emulator/sample_task_processor.py @@ -48,9 +48,7 @@ def process_task(self, task: TaskResponse) -> dict: w = [0, 0, 0, 0] result = {"weights": w} elif task.task_name == "validate": - result = { - "accuracy": [0.01, 0.02, 0.03, 0.04] - } + result = {"accuracy": [0.01, 0.02, 0.03, 0.04]} else: log.error(f"Received unknown task: {task.task_name}") diff --git a/nvflare/edge/executors/edge_dispatch_executor.py b/nvflare/edge/executors/edge_dispatch_executor.py index dec3b3275d..c96646a989 100644 --- a/nvflare/edge/executors/edge_dispatch_executor.py +++ b/nvflare/edge/executors/edge_dispatch_executor.py @@ -16,7 +16,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import ReturnCode, Shareable, ReservedHeaderKey +from nvflare.apis.shareable import ReservedHeaderKey, ReturnCode, Shareable from nvflare.edge.aggregators.edge_result_accumulator import EdgeResultAccumulator from nvflare.edge.executors.ete import EdgeTaskExecutor from nvflare.edge.web.models.result_report import ResultReport @@ -27,8 +27,7 @@ class EdgeDispatchExecutor(EdgeTaskExecutor): - """This executor dispatches tasks to edge devices and wait for the response from all devices - """ + """This executor dispatches tasks to edge devices and wait for the response from all devices""" def __init__(self, wait_time=300.0, min_devices=0, aggregator_id=None): EdgeTaskExecutor.__init__(self) @@ -58,10 +57,7 @@ def setup(self, _event_type, fl_ctx: FLContext): def convert_task(self, task_data: Shareable) -> dict: """Convert task_data to a plain dict""" - return { - "weights": task_data.get("weights"), - "task_id": self.task_id - } + return {"weights": task_data.get("weights"), "task_id": self.task_id} def convert_result(self, result: dict) -> Shareable: """Convert result from device to shareable""" @@ -78,8 +74,7 @@ def handle_task_request(self, request: TaskRequest, fl_ctx: FLContext) -> TaskRe # This device already processed current task last_task_id = self.devices.get(device_id, None) if self.task_id == last_task_id: - return TaskResponse("RETRY", job_id, 30, - message=f"Task {self.task_id} is already processed by this device") + return TaskResponse("RETRY", job_id, 30, message=f"Task {self.task_id} is already processed by this device") task_done = self.current_task.get("task_done") task_data = self.convert_task(self.current_task) @@ -90,6 +85,12 @@ def handle_task_request(self, request: TaskRequest, fl_ctx: FLContext) -> TaskRe def handle_result_report(self, report: ResultReport, fl_ctx: FLContext) -> ResultResponse: """Handle result report from device""" + if report.task_id != self.task_id: + msg = f"Task {report.task_id} is already done, result ignored" + self.log_warning(fl_ctx, msg) + # Still returns OK because this late result may be useful in certain cases + return ResultResponse("OK", task_id=self.task_id, task_name=self.task_name, message=msg) + result = self.convert_result(report.result) self.aggregator.accept(result, fl_ctx) self.num_results += 1 diff --git a/nvflare/edge/web/handlers/edge_task_handler.py b/nvflare/edge/web/handlers/edge_task_handler.py index 78898b38dd..e066fc5837 100644 --- a/nvflare/edge/web/handlers/edge_task_handler.py +++ b/nvflare/edge/web/handlers/edge_task_handler.py @@ -39,4 +39,3 @@ def handle_task(self, task_request: TaskRequest) -> TaskResponse: @abstractmethod def handle_result(self, result_report: ResultReport) -> ResultResponse: pass - diff --git a/nvflare/edge/web/handlers/lcp_task_handler.py b/nvflare/edge/web/handlers/lcp_task_handler.py index a4f9c72124..80eb9f4ce6 100644 --- a/nvflare/edge/web/handlers/lcp_task_handler.py +++ b/nvflare/edge/web/handlers/lcp_task_handler.py @@ -96,4 +96,3 @@ def _handle_task_request(self, request: Any) -> dict: reply = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE) assert isinstance(reply, dict) return reply - \ No newline at end of file diff --git a/nvflare/edge/web/handlers/sample_task_data.py b/nvflare/edge/web/handlers/sample_task_data.py index b3503c3df1..0bfecb2a95 100644 --- a/nvflare/edge/web/handlers/sample_task_data.py +++ b/nvflare/edge/web/handlers/sample_task_data.py @@ -25,8 +25,14 @@ from nvflare.edge.web.models.user_info import UserInfo jobs = [ - JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "demo_job", "ExecuTorch", - job_data={"executorch_parameters": [1.2, 3.4, 5.6]}), + JobResponse( + "OK", + str(uuid.uuid4()), + str(uuid.uuid4()), + "demo_job", + "ExecuTorch", + job_data={"executorch_parameters": [1.2, 3.4, 5.6]}, + ), JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "xgb_job", "xgboost"), JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "core_job", "coreML"), JobResponse("RETRY", str(uuid.uuid4()), retry_wait=60), @@ -64,13 +70,7 @@ def handle_task_request(device_info: DeviceInfo, user_info: UserInfo, task_reque task_name = state["next_task"] task_id = state["task_id"] - reply = TaskResponse( - "OK", - session_id, - None, - task_id, - task_name, - {}) + reply = TaskResponse("OK", session_id, None, task_id, task_name, {}) return reply @@ -90,14 +90,9 @@ def handle_result_report(device_info: DeviceInfo, user_info: UserInfo, result_re status = "OK" task_id = state["task_id"] - state["next_task"] = demo_tasks[index+1] + state["next_task"] = demo_tasks[index + 1] state["task_id"] = str(uuid.uuid4()) - reply = ResultResponse( - status, - None, - session_id, - task_id, - result_report.task_name) + reply = ResultResponse(status, None, session_id, task_id, result_report.task_name) return reply diff --git a/nvflare/edge/web/models/base_model.py b/nvflare/edge/web/models/base_model.py index 2d3dabe5a9..a00473ec13 100644 --- a/nvflare/edge/web/models/base_model.py +++ b/nvflare/edge/web/models/base_model.py @@ -39,4 +39,3 @@ def get_device_id(self) -> Optional[str]: return None return device_info.get("device_id") - diff --git a/nvflare/edge/web/models/device_info.py b/nvflare/edge/web/models/device_info.py index cdb9fa261a..24719d0a0e 100644 --- a/nvflare/edge/web/models/device_info.py +++ b/nvflare/edge/web/models/device_info.py @@ -4,8 +4,15 @@ class DeviceInfo(BaseModel): """Device information""" - def __init__(self, device_id: str, app_name: str = None, app_version: str = None, - platform: str = None, platform_version: str = None, **kwargs): + def __init__( + self, + device_id: str, + app_name: str = None, + app_version: str = None, + platform: str = None, + platform_version: str = None, + **kwargs, + ): super().__init__() self.device_id = device_id self.app_name = app_name diff --git a/nvflare/edge/web/models/task_request.py b/nvflare/edge/web/models/task_request.py index 5f7efb0443..054dd1afaf 100644 --- a/nvflare/edge/web/models/task_request.py +++ b/nvflare/edge/web/models/task_request.py @@ -4,13 +4,7 @@ class TaskRequest(BaseModel): - def __init__( - self, - device_info: DeviceInfo, - user_info: UserInfo, - job_id: str, - **kwargs - ): + def __init__(self, device_info: DeviceInfo, user_info: UserInfo, job_id: str, **kwargs): super().__init__() self.device_info = device_info self.user_info = user_info diff --git a/nvflare/edge/web/models/user_info.py b/nvflare/edge/web/models/user_info.py index 3fa4524f71..573e13ff48 100644 --- a/nvflare/edge/web/models/user_info.py +++ b/nvflare/edge/web/models/user_info.py @@ -3,8 +3,15 @@ class UserInfo(BaseModel): - def __init__(self, user_id: str = None, user_name: str = None, access_token: str = None, auth_token: str = None, - auth_session: str = None, **kwargs): + def __init__( + self, + user_id: str = None, + user_name: str = None, + access_token: str = None, + auth_token: str = None, + auth_session: str = None, + **kwargs, + ): super().__init__() self.user_id = user_id self.user_name = user_name diff --git a/nvflare/edge/web/routing_proxy.py b/nvflare/edge/web/routing_proxy.py index 6e4213981b..c1b7da8cb4 100644 --- a/nvflare/edge/web/routing_proxy.py +++ b/nvflare/edge/web/routing_proxy.py @@ -19,8 +19,8 @@ from typing import Tuple from urllib.parse import urljoin -from flask import Flask, request, Response, jsonify import requests +from flask import Flask, Response, jsonify, request from nvflare.edge.web.models.api_error import ApiError from nvflare.edge.web.web_server import FilteredJSONProvider @@ -66,7 +66,7 @@ def handle_api_error(error: ApiError): mapper = LcpMapper() -@app.route('/', methods=['GET', 'POST']) +@app.route("/", methods=["GET", "POST"]) def routing_proxy(path): device_id = request.headers.get("X-Flare-Device-ID") @@ -83,7 +83,7 @@ def routing_proxy(path): try: # Prepare headers (remove 'Host' to avoid conflicts) - headers = {key: value for key, value in request.headers if key.lower() != 'host'} + headers = {key: value for key, value in request.headers if key.lower() != "host"} # Get data from the original request data = request.get_data() @@ -96,11 +96,11 @@ def routing_proxy(path): headers=headers, data=data, cookies=request.cookies, - allow_redirects=False # Do not follow redirects + allow_redirects=False, # Do not follow redirects ) # Exclude specific headers from the target response - excluded_headers = ['server', 'date', 'content-encoding', 'content-length', 'transfer-encoding', 'connection'] + excluded_headers = ["server", "date", "content-encoding", "content-length", "transfer-encoding", "connection"] headers = {name: value for name, value in resp.headers.items() if name.lower() not in excluded_headers} headers["Via"] = "edge-proxy" @@ -112,12 +112,12 @@ def routing_proxy(path): raise ApiError(500, "PROXY_ERROR", f"Proxy request failed: {str(ex)}", ex) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler()] + handlers=[logging.StreamHandler()], ) if len(sys.argv) != 3: @@ -128,4 +128,4 @@ def routing_proxy(path): port = int(sys.argv[1]) app.json = FilteredJSONProvider(app) - app.run(host='0.0.0.0', port=port, debug=False) + app.run(host="0.0.0.0", port=port, debug=False) diff --git a/nvflare/edge/web/web_server.py b/nvflare/edge/web/web_server.py index c8c4ca81ea..04328547bf 100644 --- a/nvflare/edge/web/web_server.py +++ b/nvflare/edge/web/web_server.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any + from flask import Flask, jsonify from flask.json.provider import DefaultJSONProvider diff --git a/nvflare/edge/widgets/etd.py b/nvflare/edge/widgets/etd.py index 28712fd627..38e3ee3582 100644 --- a/nvflare/edge/widgets/etd.py +++ b/nvflare/edge/widgets/etd.py @@ -18,8 +18,7 @@ from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.apis.job_def import JobMetaKey -from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey -from nvflare.edge.constants import EdgeEventType +from nvflare.edge.constants import EdgeContextKey, EdgeEventType, EdgeProtoKey from nvflare.edge.constants import Status as EdgeStatus from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey from nvflare.fuel.f3.cellnet.utils import new_cell_message @@ -73,30 +72,19 @@ def _add_job(self, job_meta: dict): jobs.append(job_id) self.job_metas[job_id] = job_meta - def _remove_job(self, job_meta: dict): + def _remove_job(self, job_id: str): with self.lock: - job_id = job_meta.get(JobMetaKey.JOB_ID) if job_id in self.job_metas: del self.job_metas[job_id] - edge_method = job_meta.get(JobMetaKey.EDGE_METHOD) - if not edge_method: - # this is not an edge job - self.logger.info(f"no edge_method in job {job_id}") - return - - jobs = self.edge_jobs.get(edge_method) - if not jobs: - self.logger.info("no edge jobs pending") - return - assert isinstance(jobs, list) - job_id = job_meta.get(JobMetaKey.JOB_ID) - self.logger.info(f"current jobs for {edge_method}: {jobs}") - if job_id in jobs: - jobs.remove(job_id) - if not jobs: - # no more jobs for this edge method - self.edge_jobs.pop(edge_method) + # Delete this job from all methods + for edge_method, jobs in list(self.edge_jobs.items()): + assert isinstance(jobs, list) + if jobs and job_id in jobs: + jobs.remove(job_id) + if not jobs: + # no more jobs for this edge method + self.edge_jobs.pop(edge_method) def _match_job(self, caps: dict): methods = caps.get("methods") @@ -127,11 +115,11 @@ def _handle_job_launched(self, event_type: str, fl_ctx: FLContext): def _handle_job_done(self, event_type: str, fl_ctx: FLContext): self.logger.info(f"handling event {event_type}") - job_meta = fl_ctx.get_prop(FLContextKey.JOB_META) - if not job_meta: - self.logger.error(f"missing {FLContextKey.JOB_META} from fl_ctx for event {event_type}") + job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID) + if not job_id: + self.logger.error(f"missing {FLContextKey.CURRENT_JOB_ID} from fl_ctx for event {event_type}") else: - self._remove_job(job_meta) + self._remove_job(job_id) def _handle_edge_job_request(self, event_type: str, fl_ctx: FLContext): self.logger.info(f"handling event {event_type}") diff --git a/nvflare/edge/widgets/etg.py b/nvflare/edge/widgets/etg.py index b95bae3709..d3c4fc2610 100644 --- a/nvflare/edge/widgets/etg.py +++ b/nvflare/edge/widgets/etg.py @@ -22,9 +22,9 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_context import FLContext from nvflare.apis.signal import Signal -from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey +from nvflare.edge.constants import EdgeContextKey from nvflare.edge.constants import EdgeEventType as EdgeEventType -from nvflare.edge.constants import Status +from nvflare.edge.constants import EdgeProtoKey, Status from nvflare.widgets.widget import Widget diff --git a/nvflare/fuel/utils/fobs/decomposer.py b/nvflare/fuel/utils/fobs/decomposer.py index 2906f3bfc0..8a6938c5bc 100644 --- a/nvflare/fuel/utils/fobs/decomposer.py +++ b/nvflare/fuel/utils/fobs/decomposer.py @@ -204,7 +204,7 @@ def supported_type(self) -> Type[T]: def decompose(self, target: T, manager: DatumManager = None) -> Any: data = {} - if hasattr(target, '__dict__'): + if hasattr(target, "__dict__"): data[DATA_CONTENT] = vars(target) if isinstance(target, dict):