Skip to content

Commit

Permalink
updates to PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Feb 21, 2025
1 parent 44531c3 commit dd0b3d9
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 187 deletions.
3 changes: 1 addition & 2 deletions examples/advanced/sklearn-kmeans/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,4 @@ The resulting curve for `homogeneity_score` is
It can be visualized using
```commandline
tensorboard --logdir /tmp/nvflare/workspace/works/kmeans/sklearn_kmeans_uniform_3_clients
```
Note that there will be certain amount of randomness in the results.
```
Binary file modified examples/advanced/sklearn-kmeans/figs/minibatch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
93 changes: 2 additions & 91 deletions examples/advanced/sklearn-kmeans/kmeans_job_clientapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import argparse
import os
from enum import Enum
from typing import List

import numpy as np
from src.kmeans_assembler import KMeansAssembler
from utils.split_data import split_data

from nvflare import FedJob
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
Expand All @@ -28,92 +26,6 @@
from nvflare.job_config.script_runner import ScriptRunner


class SplitMethod(Enum):
UNIFORM = "uniform"
LINEAR = "linear"
SQUARE = "square"
EXPONENTIAL = "exponential"


def get_split_ratios(site_num: int, split_method: SplitMethod):
if split_method == SplitMethod.UNIFORM:
ratio_vec = np.ones(site_num)
elif split_method == SplitMethod.LINEAR:
ratio_vec = np.linspace(1, site_num, num=site_num)
elif split_method == SplitMethod.SQUARE:
ratio_vec = np.square(np.linspace(1, site_num, num=site_num))
elif split_method == SplitMethod.EXPONENTIAL:
ratio_vec = np.exp(np.linspace(1, site_num, num=site_num))
else:
raise ValueError(f"Split method {split_method.name} not implemented!")

return ratio_vec


def split_num_proportion(n, site_num, split_method: SplitMethod) -> List[int]:
split = []
ratio_vec = get_split_ratios(site_num, split_method)
total = sum(ratio_vec)
left = n
for site in range(site_num - 1):
x = int(n * ratio_vec[site] / total)
left = left - x
split.append(x)
split.append(left)
return split


def assign_data_index_to_sites(
data_size: int,
valid_fraction: float,
num_sites: int,
split_method: SplitMethod = SplitMethod.UNIFORM,
) -> dict:
if valid_fraction > 1.0:
raise ValueError("validation percent should be less than or equal to 100% of the total data")
elif valid_fraction < 1.0:
valid_size = int(round(data_size * valid_fraction, 0))
train_size = data_size - valid_size
else:
valid_size = data_size
train_size = data_size

site_sizes = split_num_proportion(train_size, num_sites, split_method)
split_data_indices = {
"valid": {"start": 0, "end": valid_size},
}
for site in range(num_sites):
site_id = site + 1
if valid_fraction < 1.0:
idx_start = valid_size + sum(site_sizes[:site])
idx_end = valid_size + sum(site_sizes[: site + 1])
else:
idx_start = sum(site_sizes[:site])
idx_end = sum(site_sizes[: site + 1])
split_data_indices[site_id] = {"start": idx_start, "end": idx_end}

return split_data_indices


def get_file_line_count(input_path: str) -> int:
count = 0
with open(input_path, "r") as fp:
for i, _ in enumerate(fp):
count += 1
return count


def split_data(
data_path: str,
num_clients: int,
valid_frac: float,
split_method: SplitMethod = SplitMethod.UNIFORM,
):
size_total_file = get_file_line_count(data_path)
site_indices = assign_data_index_to_sites(size_total_file, valid_frac, num_clients, split_method)
return site_indices


def define_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -171,7 +83,7 @@ def main():
split_mode = args.split_mode
valid_frac = args.valid_frac
job_name = f"sklearn_kmeans_{split_mode}_{num_clients}_clients"
train_script = "src/kmeans_trainer.py"
train_script = "src/kmeans_fl.py"

# Set the output workspace and job directories
workspace_dir = os.path.join(args.workspace_dir, job_name)
Expand Down Expand Up @@ -209,7 +121,6 @@ def main():
data_path,
num_clients,
valid_frac,
SplitMethod(split_mode),
)

for i in range(1, num_clients + 1):
Expand Down
1 change: 1 addition & 0 deletions examples/advanced/sklearn-kmeans/prepare_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ if [ -f "$DATASET_PATH" ]; then
else
python3 "${script_dir}"/utils/prepare_data.py \
--dataset_name iris \
--randomize 0 \
--out_path ${DATASET_PATH}
echo "Data loaded and saved in ${DATASET_PATH}"
fi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
3 changes: 2 additions & 1 deletion examples/advanced/sklearn-kmeans/utils/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def prepare_data(
x = dataset.data
y = dataset.target
if randomize:
np.random.seed(0)
print("Randomizing data sequence")

idx_random = np.random.permutation(len(y))
x = x[idx_random, :]
y = y[idx_random]
Expand Down
104 changes: 104 additions & 0 deletions examples/advanced/sklearn-kmeans/utils/split_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import List

import numpy as np


class SplitMethod(Enum):
UNIFORM = "uniform"
LINEAR = "linear"
SQUARE = "square"
EXPONENTIAL = "exponential"


def get_split_ratios(site_num: int, split_method: SplitMethod):
if split_method == SplitMethod.UNIFORM:
ratio_vec = np.ones(site_num)
elif split_method == SplitMethod.LINEAR:
ratio_vec = np.linspace(1, site_num, num=site_num)
elif split_method == SplitMethod.SQUARE:
ratio_vec = np.square(np.linspace(1, site_num, num=site_num))
elif split_method == SplitMethod.EXPONENTIAL:
ratio_vec = np.exp(np.linspace(1, site_num, num=site_num))
else:
raise ValueError(f"Split method {split_method.name} not implemented!")

return ratio_vec


def split_num_proportion(n, site_num, split_method: SplitMethod) -> List[int]:
split = []
ratio_vec = get_split_ratios(site_num, split_method)
total = sum(ratio_vec)
left = n
for site in range(site_num - 1):
x = int(n * ratio_vec[site] / total)
left = left - x
split.append(x)
split.append(left)
return split


def assign_data_index_to_sites(
data_size: int,
valid_fraction: float,
num_sites: int,
split_method: SplitMethod = SplitMethod.UNIFORM,
) -> dict:
if valid_fraction > 1.0:
raise ValueError("validation percent should be less than or equal to 100% of the total data")
elif valid_fraction < 1.0:
valid_size = int(round(data_size * valid_fraction, 0))
train_size = data_size - valid_size
else:
valid_size = data_size
train_size = data_size

site_sizes = split_num_proportion(train_size, num_sites, split_method)
split_data_indices = {
"valid": {"start": 0, "end": valid_size},
}
for site in range(num_sites):
site_id = site + 1
if valid_fraction < 1.0:
idx_start = valid_size + sum(site_sizes[:site])
idx_end = valid_size + sum(site_sizes[: site + 1])
else:
idx_start = sum(site_sizes[:site])
idx_end = sum(site_sizes[: site + 1])
split_data_indices[site_id] = {"start": idx_start, "end": idx_end}

return split_data_indices


def get_file_line_count(input_path: str) -> int:
count = 0
with open(input_path, "r") as fp:
for i, _ in enumerate(fp):
count += 1
return count


def split_data(
data_path: str,
num_clients: int,
valid_frac: float,
split_method: SplitMethod = SplitMethod.UNIFORM,
):
size_total_file = get_file_line_count(data_path)
site_indices = assign_data_index_to_sites(size_total_file, valid_frac, num_clients, split_method)
return site_indices
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import argparse
import os
from enum import Enum
from typing import List

import numpy as np
from src.kmeans_assembler import KMeansAssembler
from utils.split_data import split_data

from nvflare import FedJob
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
Expand All @@ -28,92 +26,6 @@
from nvflare.job_config.script_runner import ScriptRunner


class SplitMethod(Enum):
UNIFORM = "uniform"
LINEAR = "linear"
SQUARE = "square"
EXPONENTIAL = "exponential"


def get_split_ratios(site_num: int, split_method: SplitMethod):
if split_method == SplitMethod.UNIFORM:
ratio_vec = np.ones(site_num)
elif split_method == SplitMethod.LINEAR:
ratio_vec = np.linspace(1, site_num, num=site_num)
elif split_method == SplitMethod.SQUARE:
ratio_vec = np.square(np.linspace(1, site_num, num=site_num))
elif split_method == SplitMethod.EXPONENTIAL:
ratio_vec = np.exp(np.linspace(1, site_num, num=site_num))
else:
raise ValueError(f"Split method {split_method.name} not implemented!")

return ratio_vec


def split_num_proportion(n, site_num, split_method: SplitMethod) -> List[int]:
split = []
ratio_vec = get_split_ratios(site_num, split_method)
total = sum(ratio_vec)
left = n
for site in range(site_num - 1):
x = int(n * ratio_vec[site] / total)
left = left - x
split.append(x)
split.append(left)
return split


def assign_data_index_to_sites(
data_size: int,
valid_fraction: float,
num_sites: int,
split_method: SplitMethod = SplitMethod.UNIFORM,
) -> dict:
if valid_fraction > 1.0:
raise ValueError("validation percent should be less than or equal to 100% of the total data")
elif valid_fraction < 1.0:
valid_size = int(round(data_size * valid_fraction, 0))
train_size = data_size - valid_size
else:
valid_size = data_size
train_size = data_size

site_sizes = split_num_proportion(train_size, num_sites, split_method)
split_data_indices = {
"valid": {"start": 0, "end": valid_size},
}
for site in range(num_sites):
site_id = site + 1
if valid_fraction < 1.0:
idx_start = valid_size + sum(site_sizes[:site])
idx_end = valid_size + sum(site_sizes[: site + 1])
else:
idx_start = sum(site_sizes[:site])
idx_end = sum(site_sizes[: site + 1])
split_data_indices[site_id] = {"start": idx_start, "end": idx_end}

return split_data_indices


def get_file_line_count(input_path: str) -> int:
count = 0
with open(input_path, "r") as fp:
for i, _ in enumerate(fp):
count += 1
return count


def split_data(
data_path: str,
num_clients: int,
valid_frac: float,
split_method: SplitMethod = SplitMethod.UNIFORM,
):
size_total_file = get_file_line_count(data_path)
site_indices = assign_data_index_to_sites(size_total_file, valid_frac, num_clients, split_method)
return site_indices


def define_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -171,7 +83,7 @@ def main():
split_mode = args.split_mode
valid_frac = args.valid_frac
job_name = f"sklearn_kmeans_{split_mode}_{num_clients}_clients"
train_script = "src/kmeans_trainer.py"
train_script = "src/kmeans_fl.py"

# Set the output workspace and job directories
workspace_dir = os.path.join(args.workspace_dir, job_name)
Expand Down Expand Up @@ -209,7 +121,6 @@ def main():
data_path,
num_clients,
valid_frac,
SplitMethod(split_mode),
)

for i in range(1, num_clients + 1):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit dd0b3d9

Please sign in to comment.