Skip to content

Commit

Permalink
formats
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Feb 23, 2025
1 parent c8379c8 commit 98cfdc2
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class FedAvgNewtonRaphson(BaseFedAvg):
def __init__(self, damping_factor, epsilon=1.0, *args, **kwargs):
super().__init__(*args, **kwargs)

"""
Init function for FedAvgNewtonRaphson.
Expand Down Expand Up @@ -167,4 +167,4 @@ def update_model(self, model, model_update, replace_meta=True) -> FLModel:

model.metrics = model_update.metrics
# model.params[NPConstants.NUMPY_KEY] += model_update.params["newton_raphson_updates"]
model.params["weights"] += model_update.params["newton_raphson_updates"]
model.params["weights"] += model_update.params["newton_raphson_updates"]
2 changes: 0 additions & 2 deletions nvflare/app_common/statistics/numeric_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.app_opt.statistics.quantile_stats import get_quantiles
from nvflare.fuel.utils.log_utils import get_module_logger


T = TypeVar("T")

Expand Down Expand Up @@ -49,7 +48,6 @@ def get_global_stats(
ordered_target_metrics = StC.ordered_statistics[metric_task]
ordered_metrics = [metric for metric in ordered_target_metrics if metric in client_metrics]


for metric in ordered_metrics:
if metric not in global_metrics:
global_metrics[metric] = {}
Expand Down
3 changes: 1 addition & 2 deletions nvflare/app_opt/statistics/df/df_core_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def quantiles(self, dataset_name: str, feature_name: str, percents: List) -> Dic
if not flag:
results[StatisticsConstants.STATS_QUANTILE] = {}
return results

df = self.data[dataset_name]
data = df[feature_name]
max_bin = self.max_bin if self.max_bin else round(sqrt(len(data)))
Expand All @@ -116,4 +116,3 @@ def quantiles(self, dataset_name: str, feature_name: str, percents: List) -> Dic
# Extract the Q-Digest into a dictionary
results[StatisticsConstants.STATS_DIGEST_COORD] = digest.to_dict()
return results

10 changes: 5 additions & 5 deletions nvflare/app_opt/statistics/quantile_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

from typing import Dict

from nvflare.app_common.app_constant import StatisticsConstants as StC
from nvflare.fuel.utils.log_utils import get_module_logger


try:
from fastdigest import TDigest

TDIGEST_AVAILABLE = True
except ImportError:
TDIGEST_AVAILABLE = False
Expand All @@ -28,15 +29,15 @@


def get_quantiles(stats: Dict, statistic_configs: Dict, precision: int):

logger.info(f"get_quantiles: stats: {TDIGEST_AVAILABLE=}")

if not TDIGEST_AVAILABLE:
return {}

global_digest = {}
for client_name in stats:
global_digest = merge_quantiles(stats[client_name],global_digest)
global_digest = merge_quantiles(stats[client_name], global_digest)

quantile_config = statistic_configs.get(StC.STATS_QUANTILE)
return compute_quantiles(global_digest, quantile_config, precision)
Expand Down Expand Up @@ -87,12 +88,11 @@ def compute_quantiles(g_digest: dict, quantile_config: Dict, precision: int) ->
feature_metrics = g_digest[ds_name]
for feature_name in feature_metrics:
digest = feature_metrics[feature_name]
percentiles = get_target_quantiles(quantile_config,feature_name)
percentiles = get_target_quantiles(quantile_config, feature_name)
quantile_values = {}
for percentile in percentiles:
quantile_values[percentile] = round(digest.quantile(percentile), precision)

g_ds_metrics[ds_name][feature_name] = quantile_values

return g_ds_metrics

1 change: 1 addition & 0 deletions tests/unit_test/app_common/statistics/quantile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

try:
from fastdigest import TDigest

TDIGEST_AVAILABLE = True
except ImportError:
TDIGEST_AVAILABLE = False
Expand Down

0 comments on commit 98cfdc2

Please sign in to comment.