Skip to content

Commit

Permalink
Changed _remove_job to use job_id to get around the bug that job_meta is missing
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Mar 3, 2025
1 parent f8be5d7 commit e8de762
Show file tree
Hide file tree
Showing 18 changed files with 79 additions and 90 deletions.
4 changes: 2 additions & 2 deletions nvflare/edge/aggregators/edge_result_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator

import numpy as np


class EdgeResultAccumulator(Aggregator):
def __init__(self):
Expand Down
10 changes: 8 additions & 2 deletions nvflare/edge/emulator/device_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@

class DeviceEmulator:

def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo,
capabilities: Optional[dict], processor: DeviceTaskProcessor):
def __init__(
self,
endpoint: str,
device_info: DeviceInfo,
user_info: UserInfo,
capabilities: Optional[dict],
processor: DeviceTaskProcessor,
):
self.device_info = device_info
self.device_id = device_info.device_id
self.user_info = user_info
Expand Down
1 change: 0 additions & 1 deletion nvflare/edge/emulator/device_task_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,3 @@ def process_task(self, task: TaskResponse) -> dict:
The result as a dict
"""
pass

8 changes: 2 additions & 6 deletions nvflare/edge/emulator/run_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
def device_run(endpoint_url: str, device_info: DeviceInfo, user_info: UserInfo, processor: DeviceTaskProcessor):
device_id = device_info.device_id
try:
capabilities = {
"methods": ["xgboost", "cnn"],
"cpu": 16,
"gpu": 1024
}
capabilities = {"methods": ["xgboost", "cnn"], "cpu": 16, "gpu": 1024}
emulator = DeviceEmulator(endpoint_url, device_info, user_info, capabilities, processor)
emulator.run()

Expand Down Expand Up @@ -60,7 +56,7 @@ def run_emulator(endpoint_url: str, num: int):
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()]
handlers=[logging.StreamHandler()],
)

n = len(sys.argv)
Expand Down
4 changes: 1 addition & 3 deletions nvflare/edge/emulator/sample_task_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def process_task(self, task: TaskResponse) -> dict:
w = [0, 0, 0, 0]
result = {"weights": w}
elif task.task_name == "validate":
result = {
"accuracy": [0.01, 0.02, 0.03, 0.04]
}
result = {"accuracy": [0.01, 0.02, 0.03, 0.04]}
else:
log.error(f"Received unknown task: {task.task_name}")

Expand Down
19 changes: 10 additions & 9 deletions nvflare/edge/executors/edge_dispatch_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, ReservedHeaderKey
from nvflare.apis.shareable import ReservedHeaderKey, ReturnCode, Shareable
from nvflare.edge.aggregators.edge_result_accumulator import EdgeResultAccumulator
from nvflare.edge.executors.ete import EdgeTaskExecutor
from nvflare.edge.web.models.result_report import ResultReport
Expand All @@ -27,8 +27,7 @@


class EdgeDispatchExecutor(EdgeTaskExecutor):
"""This executor dispatches tasks to edge devices and wait for the response from all devices
"""
"""This executor dispatches tasks to edge devices and wait for the response from all devices"""

def __init__(self, wait_time=300.0, min_devices=0, aggregator_id=None):
EdgeTaskExecutor.__init__(self)
Expand Down Expand Up @@ -58,10 +57,7 @@ def setup(self, _event_type, fl_ctx: FLContext):
def convert_task(self, task_data: Shareable) -> dict:
"""Convert task_data to a plain dict"""

return {
"weights": task_data.get("weights"),
"task_id": self.task_id
}
return {"weights": task_data.get("weights"), "task_id": self.task_id}

def convert_result(self, result: dict) -> Shareable:
"""Convert result from device to shareable"""
Expand All @@ -78,8 +74,7 @@ def handle_task_request(self, request: TaskRequest, fl_ctx: FLContext) -> TaskRe
# This device already processed current task
last_task_id = self.devices.get(device_id, None)
if self.task_id == last_task_id:
return TaskResponse("RETRY", job_id, 30,
message=f"Task {self.task_id} is already processed by this device")
return TaskResponse("RETRY", job_id, 30, message=f"Task {self.task_id} is already processed by this device")

task_done = self.current_task.get("task_done")
task_data = self.convert_task(self.current_task)
Expand All @@ -90,6 +85,12 @@ def handle_task_request(self, request: TaskRequest, fl_ctx: FLContext) -> TaskRe
def handle_result_report(self, report: ResultReport, fl_ctx: FLContext) -> ResultResponse:
"""Handle result report from device"""

if report.task_id != self.task_id:
msg = f"Task {report.task_id} is already done, result ignored"
self.log_warning(fl_ctx, msg)
# Still returns OK because this late result may be useful in certain cases
return ResultResponse("OK", task_id=self.task_id, task_name=self.task_name, message=msg)

result = self.convert_result(report.result)
self.aggregator.accept(result, fl_ctx)
self.num_results += 1
Expand Down
1 change: 0 additions & 1 deletion nvflare/edge/web/handlers/edge_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@ def handle_task(self, task_request: TaskRequest) -> TaskResponse:
@abstractmethod
def handle_result(self, result_report: ResultReport) -> ResultResponse:
pass

1 change: 0 additions & 1 deletion nvflare/edge/web/handlers/lcp_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,3 @@ def _handle_task_request(self, request: Any) -> dict:
reply = fl_ctx.get_prop(EdgeContextKey.REPLY_TO_EDGE)
assert isinstance(reply, dict)
return reply

27 changes: 11 additions & 16 deletions nvflare/edge/web/handlers/sample_task_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@
from nvflare.edge.web.models.user_info import UserInfo

jobs = [
JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "demo_job", "ExecuTorch",
job_data={"executorch_parameters": [1.2, 3.4, 5.6]}),
JobResponse(
"OK",
str(uuid.uuid4()),
str(uuid.uuid4()),
"demo_job",
"ExecuTorch",
job_data={"executorch_parameters": [1.2, 3.4, 5.6]},
),
JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "xgb_job", "xgboost"),
JobResponse("OK", str(uuid.uuid4()), str(uuid.uuid4()), "core_job", "coreML"),
JobResponse("RETRY", str(uuid.uuid4()), retry_wait=60),
Expand Down Expand Up @@ -64,13 +70,7 @@ def handle_task_request(device_info: DeviceInfo, user_info: UserInfo, task_reque
task_name = state["next_task"]
task_id = state["task_id"]

reply = TaskResponse(
"OK",
session_id,
None,
task_id,
task_name,
{})
reply = TaskResponse("OK", session_id, None, task_id, task_name, {})

return reply

Expand All @@ -90,14 +90,9 @@ def handle_result_report(device_info: DeviceInfo, user_info: UserInfo, result_re
status = "OK"

task_id = state["task_id"]
state["next_task"] = demo_tasks[index+1]
state["next_task"] = demo_tasks[index + 1]
state["task_id"] = str(uuid.uuid4())

reply = ResultResponse(
status,
None,
session_id,
task_id,
result_report.task_name)
reply = ResultResponse(status, None, session_id, task_id, result_report.task_name)

return reply
1 change: 0 additions & 1 deletion nvflare/edge/web/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@ def get_device_id(self) -> Optional[str]:
return None

return device_info.get("device_id")

11 changes: 9 additions & 2 deletions nvflare/edge/web/models/device_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
class DeviceInfo(BaseModel):
"""Device information"""

def __init__(self, device_id: str, app_name: str = None, app_version: str = None,
platform: str = None, platform_version: str = None, **kwargs):
def __init__(
self,
device_id: str,
app_name: str = None,
app_version: str = None,
platform: str = None,
platform_version: str = None,
**kwargs,
):
super().__init__()
self.device_id = device_id
self.app_name = app_name
Expand Down
8 changes: 1 addition & 7 deletions nvflare/edge/web/models/task_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@


class TaskRequest(BaseModel):
def __init__(
self,
device_info: DeviceInfo,
user_info: UserInfo,
job_id: str,
**kwargs
):
def __init__(self, device_info: DeviceInfo, user_info: UserInfo, job_id: str, **kwargs):
super().__init__()
self.device_info = device_info
self.user_info = user_info
Expand Down
11 changes: 9 additions & 2 deletions nvflare/edge/web/models/user_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@

class UserInfo(BaseModel):

def __init__(self, user_id: str = None, user_name: str = None, access_token: str = None, auth_token: str = None,
auth_session: str = None, **kwargs):
def __init__(
self,
user_id: str = None,
user_name: str = None,
access_token: str = None,
auth_token: str = None,
auth_session: str = None,
**kwargs,
):
super().__init__()
self.user_id = user_id
self.user_name = user_name
Expand Down
16 changes: 8 additions & 8 deletions nvflare/edge/web/routing_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from typing import Tuple
from urllib.parse import urljoin

from flask import Flask, request, Response, jsonify
import requests
from flask import Flask, Response, jsonify, request

from nvflare.edge.web.models.api_error import ApiError
from nvflare.edge.web.web_server import FilteredJSONProvider
Expand Down Expand Up @@ -66,7 +66,7 @@ def handle_api_error(error: ApiError):
mapper = LcpMapper()


@app.route('/<path:path>', methods=['GET', 'POST'])
@app.route("/<path:path>", methods=["GET", "POST"])
def routing_proxy(path):

device_id = request.headers.get("X-Flare-Device-ID")
Expand All @@ -83,7 +83,7 @@ def routing_proxy(path):

try:
# Prepare headers (remove 'Host' to avoid conflicts)
headers = {key: value for key, value in request.headers if key.lower() != 'host'}
headers = {key: value for key, value in request.headers if key.lower() != "host"}

# Get data from the original request
data = request.get_data()
Expand All @@ -96,11 +96,11 @@ def routing_proxy(path):
headers=headers,
data=data,
cookies=request.cookies,
allow_redirects=False # Do not follow redirects
allow_redirects=False, # Do not follow redirects
)

Check failure (Code scanning / CodeQL): Full server-side request forgery (Critical)

The full URL of this request depends on a user-provided value.

# Exclude specific headers from the target response
excluded_headers = ['server', 'date', 'content-encoding', 'content-length', 'transfer-encoding', 'connection']
excluded_headers = ["server", "date", "content-encoding", "content-length", "transfer-encoding", "connection"]
headers = {name: value for name, value in resp.headers.items() if name.lower() not in excluded_headers}
headers["Via"] = "edge-proxy"

Expand All @@ -112,12 +112,12 @@ def routing_proxy(path):
raise ApiError(500, "PROXY_ERROR", f"Proxy request failed: {str(ex)}", ex)


if __name__ == '__main__':
if __name__ == "__main__":

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()]
handlers=[logging.StreamHandler()],
)

if len(sys.argv) != 3:
Expand All @@ -128,4 +128,4 @@ def routing_proxy(path):
port = int(sys.argv[1])

app.json = FilteredJSONProvider(app)
app.run(host='0.0.0.0', port=port, debug=False)
app.run(host="0.0.0.0", port=port, debug=False)
1 change: 1 addition & 0 deletions nvflare/edge/web/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from flask import Flask, jsonify
from flask.json.provider import DefaultJSONProvider

Expand Down
40 changes: 14 additions & 26 deletions nvflare/edge/widgets/etd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import JobMetaKey
from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey
from nvflare.edge.constants import EdgeEventType
from nvflare.edge.constants import EdgeContextKey, EdgeEventType, EdgeProtoKey
from nvflare.edge.constants import Status as EdgeStatus
from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey
from nvflare.fuel.f3.cellnet.utils import new_cell_message
Expand Down Expand Up @@ -73,30 +72,19 @@ def _add_job(self, job_meta: dict):
jobs.append(job_id)
self.job_metas[job_id] = job_meta

def _remove_job(self, job_meta: dict):
def _remove_job(self, job_id: str):
with self.lock:
job_id = job_meta.get(JobMetaKey.JOB_ID)
if job_id in self.job_metas:
del self.job_metas[job_id]
edge_method = job_meta.get(JobMetaKey.EDGE_METHOD)
if not edge_method:
# this is not an edge job
self.logger.info(f"no edge_method in job {job_id}")
return

jobs = self.edge_jobs.get(edge_method)
if not jobs:
self.logger.info("no edge jobs pending")
return

assert isinstance(jobs, list)
job_id = job_meta.get(JobMetaKey.JOB_ID)
self.logger.info(f"current jobs for {edge_method}: {jobs}")
if job_id in jobs:
jobs.remove(job_id)
if not jobs:
# no more jobs for this edge method
self.edge_jobs.pop(edge_method)
# Delete this job from all methods
for edge_method, jobs in list(self.edge_jobs.items()):
assert isinstance(jobs, list)
if jobs and job_id in jobs:
jobs.remove(job_id)
if not jobs:
# no more jobs for this edge method
self.edge_jobs.pop(edge_method)

def _match_job(self, caps: dict):
methods = caps.get("methods")
Expand Down Expand Up @@ -127,11 +115,11 @@ def _handle_job_launched(self, event_type: str, fl_ctx: FLContext):

def _handle_job_done(self, event_type: str, fl_ctx: FLContext):
self.logger.info(f"handling event {event_type}")
job_meta = fl_ctx.get_prop(FLContextKey.JOB_META)
if not job_meta:
self.logger.error(f"missing {FLContextKey.JOB_META} from fl_ctx for event {event_type}")
job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
if not job_id:
self.logger.error(f"missing {FLContextKey.CURRENT_JOB_ID} from fl_ctx for event {event_type}")
else:
self._remove_job(job_meta)
self._remove_job(job_id)

def _handle_edge_job_request(self, event_type: str, fl_ctx: FLContext):
self.logger.info(f"handling event {event_type}")
Expand Down
4 changes: 2 additions & 2 deletions nvflare/edge/widgets/etg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.signal import Signal
from nvflare.edge.constants import EdgeContextKey, EdgeProtoKey
from nvflare.edge.constants import EdgeContextKey
from nvflare.edge.constants import EdgeEventType as EdgeEventType
from nvflare.edge.constants import Status
from nvflare.edge.constants import EdgeProtoKey, Status
from nvflare.widgets.widget import Widget


Expand Down
2 changes: 1 addition & 1 deletion nvflare/fuel/utils/fobs/decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def supported_type(self) -> Type[T]:
def decompose(self, target: T, manager: DatumManager = None) -> Any:
data = {}

if hasattr(target, '__dict__'):
if hasattr(target, "__dict__"):
data[DATA_CONTENT] = vars(target)

if isinstance(target, dict):
Expand Down

0 comments on commit e8de762

Please sign in to comment.