Merge pull request #5 from camille-004/cd/fedavg-weights
fix: update fedavg to properly weight clients
camille-004 authored Dec 8, 2024
2 parents a30e8e3 + ac2d50c · commit 6533fd1
Showing 6 changed files with 212 additions and 37 deletions.
28 changes: 22 additions & 6 deletions examples/mnist/run_experiment.py
@@ -17,11 +17,27 @@
from nanofed.trainer import TorchTrainer, TrainingConfig


async def run_client(client_id: str, server_url: str, data_dir: Path):
"""Run a federated client."""
async def run_client(client_id: str, server_url: str, data_dir: Path, num_samples: int):
"""Run a federated client.
Parameters
----------
client_id : str
Unique identifier for this client
server_url : str
URL of the FL server
data_dir : Path
Directory containing the dataset
num_samples : int
Number of samples for this client's dataset
"""
# Calculate subset fraction based on desired number of samples
# MNIST train set has 60000 samples
subset_fraction = num_samples / 60000

# Prepare the client's local dataset
train_loader = load_mnist_data(
data_dir=data_dir, batch_size=64, train=True
data_dir=data_dir, batch_size=64, train=True, subset_fraction=subset_fraction
)

# Client training configuration
@@ -112,9 +128,9 @@ async def main():
# Run the coordinator and clients concurrently
await asyncio.gather(
coordinate(coordinator),
run_client("client_1", "http://0.0.0.0:8080", data_dir),
run_client("client_2", "http://0.0.0.0:8080", data_dir),
run_client("client_3", "http://0.0.0.0:8080", data_dir),
run_client("client_1", "http://0.0.0.0:8080", data_dir, num_samples=12000),
run_client("client_2", "http://0.0.0.0:8080", data_dir, num_samples=8000),
run_client("client_3", "http://0.0.0.0:8080", data_dir, num_samples=4000),
)


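The example now splits MNIST unevenly across the three clients (12,000, 8,000, and 4,000 samples), so the server should no longer treat them equally. As a quick sanity check, not part of the example script, here is how the subset fractions and the FedAvg weights implied by these counts work out (assuming the 60,000-sample MNIST train set used above):

    # Standalone sketch: subset fractions and expected FedAvg weights for the
    # three clients configured above (n_k / 60000 and n_k / sum(n_k)).
    num_samples = {"client_1": 12_000, "client_2": 8_000, "client_3": 4_000}

    subset_fractions = {cid: n / 60_000 for cid, n in num_samples.items()}
    # {'client_1': 0.2, 'client_2': 0.133..., 'client_3': 0.066...}

    total = sum(num_samples.values())
    expected_weights = {cid: n / total for cid, n in num_samples.items()}
    # {'client_1': 0.5, 'client_2': 0.333..., 'client_3': 0.166...}

    print(subset_fractions)
    print(expected_weights)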
4 changes: 2 additions & 2 deletions nanofed/communication/http/client.py
@@ -103,7 +103,7 @@ def _get_url(self, endpoint: str) -> str:
@log_exec
async def fetch_global_model(self) -> tuple[dict[str, torch.Tensor], int]:
"""Fetch current global model from server."""
with self._logger.context("client.http", self._client_id):
with self._logger.context("client.http"):
if self._session is None:
raise NanoFedError("Client session not initialized")

@@ -160,7 +160,7 @@ async def submit_update(
self, model: ModelProtocol, metrics: dict[str, float]
) -> bool:
"""Submit model udpate to server."""
with self._logger.context("client.http", self._client_id):
with self._logger.context("client.http"):
if self._session is None:
raise NanoFedError("Client session not initialized")

11 changes: 11 additions & 0 deletions nanofed/orchestration/coordinator.py
@@ -236,10 +236,20 @@ async def train_round(self) -> RoundMetrics:
for update in self._server._updates.values()
]

weights = self._aggregator._compute_weights(client_updates)

client_weights = {
update["client_id"]: weight
for update, weight in zip(client_updates, weights)
}

client_metrics = [
{
"client_id": update.get("client_id"),
"metrics": update.get("metrics", {}),
"weight": client_weights[
str(update.get("client_id", ""))
],
}
for update in client_updates
]
@@ -253,6 +263,7 @@ async def train_round(self) -> RoundMetrics:
"round": self._current_round,
"num_clients": len(client_updates),
"client_metrics": client_metrics,
"client_weights": client_weights,
},
metrics=aggregation_result.metrics,
)
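With this change, each round's metrics record both the raw client metrics and the weight the aggregator assigned to each client. Purely as an illustration (the numbers below are invented; only the keys shown in the diff above come from the code), a three-client round would produce entries shaped like this:

    # Illustrative shape of the new per-round bookkeeping; values are made up.
    client_weights = {"client_1": 0.5, "client_2": 0.3333, "client_3": 0.1667}
    client_metrics = [
        {"client_id": cid, "metrics": {"loss": loss}, "weight": client_weights[cid]}
        for cid, loss in [("client_1", 0.12), ("client_2", 0.18), ("client_3", 0.25)]
    ]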
26 changes: 18 additions & 8 deletions nanofed/server/aggregator/base.py
@@ -32,14 +32,6 @@ def __init__(self) -> None:
def current_round(self) -> int:
return self._current_round

def _compute_weights(self, num_clients: int) -> list[float]:
"""Compute aggregation weights for clients."""
if num_clients not in self._weights_cache:
self._weights_cache[num_clients] = [
1.0 / num_clients
] * num_clients
return self._weights_cache[num_clients]

def _validate_updates(self, updates: Sequence[ModelUpdate]) -> None:
"""Validate model updates before aggregation."""
if not updates:
@@ -64,3 +56,21 @@ def aggregate(
) -> AggregationResult[T]:
"""Aggregate model updates."""
pass

@abstractmethod
def _compute_weights(self, updates: Sequence[ModelUpdate]) -> list[float]:
"""Compute aggregation weights for clients.
Each aggregation strategy should implement its own weighting scheme.
Parameters
----------
updates : Sequence[ModelUpdate]
Sequence of model updates from clients.
Returns
-------
list[float]
List of weights, one per client update.
"""
pass
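Since _compute_weights is now abstract, every aggregation strategy must spell out its own weighting scheme. A hypothetical subclass that restores the old uniform 1 / num_clients behaviour under the new interface might look roughly like this (BaseAggregator is an assumed name for the ABC defined in this file; ModelUpdate is the type used in the signatures above):

    # Hypothetical sketch only; "BaseAggregator" is an assumed class name.
    from collections.abc import Sequence

    from nanofed.core.types import ModelUpdate


    class UniformAggregator(BaseAggregator):
        """Weights every client equally, regardless of dataset size."""

        def _compute_weights(self, updates: Sequence[ModelUpdate]) -> list[float]:
            n = len(updates)
            return [1.0 / n] * n if n else []

        def aggregate(self, model, updates):
            raise NotImplementedError  # aggregation itself is out of scope for this sketch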
48 changes: 39 additions & 9 deletions nanofed/server/aggregator/fedavg.py
@@ -50,7 +50,7 @@ def aggregate(
"""Aggregate updates using FedAvg algorithm."""
self._validate_updates(updates)

weights = self._compute_weights(len(updates))
weights = self._compute_weights(updates)
state_agg: dict[str, torch.Tensor] = {}

for key, value in updates[0]["model_state"].items():
@@ -65,7 +65,7 @@ def aggregate(
# Update global model
model.load_state_dict(state_agg)

avg_metrics = self._aggregate_metrics(updates)
avg_metrics = self._aggregate_metrics(updates, weights)

self._current_round += 1

@@ -78,18 +78,48 @@
)

def _aggregate_metrics(
self, updates: Sequence[ModelUpdate]
self, updates: Sequence[ModelUpdate], weights: list[float]
) -> dict[str, float]:
"""Aggregate metrics from all updates."""
all_metrics: dict[str, list[float]] = {}
# (value, weight) pairs
all_metrics: dict[str, list[tuple[float, float]]] = {}

for update in updates:
for update, weight in zip(updates, weights):
for key, value in update["metrics"].items():
if isinstance(value, (int, float)):
all_metrics.setdefault(key, []).append(value)
if key not in all_metrics:
all_metrics[key] = []
all_metrics[key].append((float(value), weight))

return {
key: sum(values) / len(values)
for key, values in all_metrics.items()
if values
key: sum(val * w for val, w in value_pairs)
/ sum(w for _, w in value_pairs)
for key, value_pairs in all_metrics.items()
if value_pairs
}

def _compute_weights(self, updates: Sequence[ModelUpdate]) -> list[float]:
# In FedAvg, each client's weight is proportional to its local dataset
# size:
# w_k = n_k / n where n_k is client k's dataset size and n is total
# samples.
sample_counts = []
for update in updates:
num_samples = update["metrics"].get("num_samples") or update[
"metrics"
].get("samples_processed")
if num_samples is None:
self._logger.warning(
f"Client {update['client_id']} did not report sample "
f"count. Using 1.0"
)
num_samples = 1.0
sample_counts.append(num_samples)

total_samples = sum(sample_counts)
weights = [count / total_samples for count in sample_counts]

self._logger.debug(f"Client sample counts: {sample_counts}")
self._logger.debug(f"Computed weights: {weights}")

return weights
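To see the weighting in action, the arithmetic that the unit test below relies on can be written out directly (sample counts of 1,000 and 2,000, as in the test):

    # Worked example of the FedAvg math checked by the test below.
    sample_counts = [1000, 2000]
    total = sum(sample_counts)
    weights = [n / total for n in sample_counts]  # [1/3, 2/3]

    # Weighted parameter average: 1.0 * 1/3 + 2.0 * 2/3 = 5/3 ≈ 1.667
    fc_weight_value = 1.0 * weights[0] + 2.0 * weights[1]

    # Weighted metric averages: loss 0.1 * 1/3 + 0.2 * 2/3 ≈ 0.167,
    # accuracy 0.9 * 1/3 + 0.8 * 2/3 ≈ 0.833
    weighted_loss = 0.1 * weights[0] + 0.2 * weights[1]
    weighted_accuracy = 0.9 * weights[0] + 0.8 * weights[1]

    print(weights, fc_weight_value, weighted_loss, weighted_accuracy)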
132 changes: 120 additions & 12 deletions tests/unit/server/aggregator/test_fedavg.py
@@ -1,8 +1,11 @@
from datetime import UTC, datetime

import pytest
import torch

from nanofed.core.exceptions import AggregationError
from nanofed.core.interfaces import ModelProtocol
from nanofed.core.types import ModelUpdate
from nanofed.server.aggregator.fedavg import FedAvgAggregator


@@ -20,40 +23,145 @@ def test_fedavg_aggregate_success():
aggregator = FedAvgAggregator()

updates = [
{
"model_state": {
ModelUpdate(
client_id="client1",
round_number=1,
model_state={
"fc.weight": [[1.0, 1.0], [1.0, 1.0]],
"fc.bias": [0.5, 0.5],
},
"round_number": 1,
"metrics": {"loss": 0.1, "accuracy": 0.9},
},
{
"model_state": {
metrics={
"loss": 0.1,
"accuracy": 0.9,
"samples_processed": 1000, # First client has 1000 samples
},
timestamp=datetime.now(UTC),
),
ModelUpdate(
client_id="client2",
round_number=1,
model_state={
"fc.weight": [[2.0, 2.0], [2.0, 2.0]],
"fc.bias": [1.0, 1.0],
},
"round_number": 1,
"metrics": {"loss": 0.2, "accuracy": 0.8},
},
metrics={
"loss": 0.2,
"accuracy": 0.8,
"samples_processed": 2000, # Second client has 2000 samples
},
timestamp=datetime.now(UTC),
),
]

result = aggregator.aggregate(model, updates)

assert pytest.approx(result.metrics["loss"], rel=1e-5) == (
0.1 * 1 / 3 + 0.2 * 2 / 3
)
assert pytest.approx(result.metrics["accuracy"], rel=1e-5) == (
0.9 * 1 / 3 + 0.8 * 2 / 3
)

# Check model parameters are weighted correctly
state_dict = model.state_dict()
assert torch.allclose(
state_dict["fc.weight"],
torch.tensor([[1.667, 1.667], [1.667, 1.667]], dtype=torch.float32),
rtol=1e-3,
)
assert torch.allclose(
state_dict["fc.bias"],
torch.tensor([0.833, 0.833], dtype=torch.float32),
rtol=1e-3,
)


def test_fedavg_aggregate_missing_samples():
"""Test FedAvg when sample counts are missing."""
model = DummyModel()
aggregator = FedAvgAggregator()

updates = [
ModelUpdate(
client_id="client1",
round_number=1,
model_state={
"fc.weight": [[1.0, 1.0], [1.0, 1.0]],
"fc.bias": [0.5, 0.5],
},
metrics={"loss": 0.1, "accuracy": 0.9}, # No samples_processed
timestamp=datetime.now(UTC),
),
ModelUpdate(
client_id="client2",
round_number=1,
model_state={
"fc.weight": [[2.0, 2.0], [2.0, 2.0]],
"fc.bias": [1.0, 1.0],
},
metrics={"loss": 0.2, "accuracy": 0.8}, # No samples_processed
timestamp=datetime.now(UTC),
),
]

result = aggregator.aggregate(model, updates)

# With missing sample counts, should default to equal weights
assert pytest.approx(result.metrics["loss"], 0.001) == 0.15
assert pytest.approx(result.metrics["accuracy"], 0.001) == 0.85


def test_fedavg_aggregate_validation_error():
"""Test FedAvg validation for different round numbers."""
model = DummyModel()
aggregator = FedAvgAggregator()

updates = [
{"model_state": {}, "round_number": 1, "metrics": {}},
{"model_state": {}, "round_number": 2, "metrics": {}},
ModelUpdate(
client_id="client1",
round_number=1,
model_state={},
metrics={},
timestamp=datetime.now(UTC),
),
ModelUpdate(
client_id="client2",
round_number=2,
model_state={},
metrics={},
timestamp=datetime.now(UTC),
),
]

with pytest.raises(
AggregationError, match="Updates from different rounds: {1, 2}"
):
aggregator.aggregate(model, updates)


def test_fedavg_aggregate_different_architectures():
"""Test FedAvg validation for different model architectures."""
model = DummyModel()
aggregator = FedAvgAggregator()

updates = [
ModelUpdate(
client_id="client1",
round_number=1,
model_state={"layer1": [1.0]},
metrics={},
timestamp=datetime.now(UTC),
),
ModelUpdate(
client_id="client2",
round_number=1,
model_state={"layer2": [1.0]},
metrics={},
timestamp=datetime.now(UTC),
),
]

with pytest.raises(
AggregationError, match="Inconsistent model architectures in updates"
):
aggregator.aggregate(model, updates)
