[WIP]: Secure Aggregation cleanup #1420

Draft: wants to merge 4 commits into base: develop
1 change: 1 addition & 0 deletions docs/about/features_index/secure_aggregation.rst
@@ -12,6 +12,7 @@ TaskRunner API
-------------------------------------

OpenFL treats SecAgg as a core security feature; it can be enabled for any experiment by simply modifying the plan.
**NOTE**: `pycryptodome <https://pypi.org/project/pycryptodome/>`_ is a required dependency that must be installed on the participant nodes before starting the experiment.
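
A quick way to confirm the dependency on a participant node before launching the experiment (an illustrative check, not part of the documented workflow; note that pycryptodome installs the ``Crypto`` import package):

import importlib.util

if importlib.util.find_spec("Crypto") is None:
    raise RuntimeError(
        "pycryptodome is required for secure aggregation; install it with 'pip install pycryptodome'."
    )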

The following plan shows secure aggregation being enabled on the `keras/mnist <https://github.com/securefederatedai/openfl/tree/develop/openfl-workspace/keras/mnist>`_ workspace.

41 changes: 32 additions & 9 deletions openfl/callbacks/secure_aggregation.py
@@ -14,17 +14,9 @@
from openfl.callbacks.callback import Callback
from openfl.protocols import utils
from openfl.utilities import TensorKey
from openfl.utilities.secagg import (
calculate_shared_mask,
create_ciphertext,
create_secret_shares,
decipher_ciphertext,
generate_agreed_key,
generate_key_pair,
pseudo_random_generator,
)

logger = logging.getLogger(__name__)
_required_package = "pycryptodome"


class SecAggBootstrapping(Callback):
@@ -39,6 +31,19 @@ class SecAggBootstrapping(Callback):
It also requires the tensor-db client to be set.
"""

def __init__(self):
super().__init__()
# Check if pycryptodome is installed.
import pkg_resources

try:
pkg_resources.get_distribution(_required_package)
except pkg_resources.DistributionNotFound:
raise Exception(
f"'{__required_package}' not installed."
"This package is necessary when secure aggregation is enabled."
)

def on_experiment_begin(self, logs=None):
"""
Used to perform secure aggregation setup before the experiment begins.
@@ -73,6 +78,8 @@ def _generate_keys(self):
secure aggregation mechanism.
5. Updates the instance parameters with the local result.
"""
from openfl.utilities.secagg import generate_key_pair

private_key1, public_key1 = generate_key_pair()
private_key2, public_key2 = generate_key_pair()

@@ -127,6 +134,12 @@ def _generate_ciphertexts(self, public_keys):
indices and values are lists containing public keys of the
collaborators.
"""
from openfl.utilities.secagg import (
create_ciphertext,
create_secret_shares,
generate_agreed_key,
)

logger.debug("SecAgg: Generating ciphertexts to be shared with other collaborators")
collaborator_count = len(public_keys)
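
Conceptually, this step splits the collaborator's private seed into secret shares and encrypts one share per peer under the corresponding pairwise agreed key, so the aggregator can relay the ciphertexts without reading them. A sketch using pycryptodome's Shamir and AES-GCM primitives, as an illustrative stand-in for create_secret_shares, generate_agreed_key, and create_ciphertext (whose actual formats may differ):

from Crypto.Cipher import AES
from Crypto.Protocol.SecretSharing import Shamir

def share_and_encrypt_sketch(private_seed: bytes, agreed_keys: dict, threshold: int) -> dict:
    # private_seed must be 16 bytes for Shamir.split; agreed_keys maps a peer's
    # id to the 16- or 32-byte AES key agreed with that peer.
    shares = Shamir.split(threshold, len(agreed_keys), private_seed)
    ciphertexts = {}
    for (share_idx, share), (peer, key) in zip(shares, agreed_keys.items()):
        cipher = AES.new(key, AES.MODE_GCM)
        ct, tag = cipher.encrypt_and_digest(share)
        # Only the addressed peer holds the agreed key needed to open this share.
        ciphertexts[peer] = (share_idx, cipher.nonce, ct, tag)
    return ciphertexts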

@@ -184,6 +197,11 @@ def _decrypt_ciphertexts(self, public_keys):
public_keys (dict): A dictionary containing the public keys of the
collaborators.
"""
from openfl.utilities.secagg import (
decipher_ciphertext,
generate_agreed_key,
)

logger.debug("SecAgg: fetching addressed ciphertexts from the aggregator")

ciphertexts = self._fetch_from_aggregator("ciphertexts")
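
On the receiving side, each collaborator opens only the ciphertexts addressed to it, using the same pairwise agreed key; the authentication tag detects tampering in transit. A matching sketch (again an assumption about the underlying primitives, not the exact decipher_ciphertext signature):

from Crypto.Cipher import AES

def decrypt_share_sketch(agreed_key: bytes, share_idx: int, nonce: bytes, ct: bytes, tag: bytes):
    # Raises ValueError if the ciphertext or tag was modified in transit.
    cipher = AES.new(agreed_key, AES.MODE_GCM, nonce=nonce)
    return share_idx, cipher.decrypt_and_verify(ct, tag)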
@@ -213,6 +231,11 @@ def _generate_masks(self):
Use the private seed and agreed keys to calculate the masks to be
added to the gradients.
"""
from openfl.utilities.secagg import (
calculate_shared_mask,
pseudo_random_generator,
)

private_mask = pseudo_random_generator(self.params.get("private_seed"))
shared_mask = calculate_shared_mask(self.params.get("agreed_keys"))
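
The masking follows the standard pairwise-cancelling construction: a private mask expanded from the collaborator's own seed, plus per-pair masks derived from each agreed key and added with opposite signs by the two ends of the pair, so that they vanish in the aggregate. A sketch of what pseudo_random_generator and calculate_shared_mask plausibly compute (shapes, dtypes, PRG choice, and the ordering rule are assumptions):

import numpy as np

def generate_masks_sketch(private_seed: int, agreed_keys: dict, own_rank: int, shape) -> tuple:
    # Private mask: deterministic noise expanded from this collaborator's seed.
    private_mask = np.random.default_rng(private_seed).random(shape)
    # Shared mask: each pair contributes noise with opposite signs, so the
    # contributions cancel once every collaborator's masked update is summed.
    shared_mask = np.zeros(shape)
    for peer_rank, key in agreed_keys.items():
        sign = 1.0 if own_rank < peer_rank else -1.0
        shared_mask += sign * np.random.default_rng(int.from_bytes(key, "big")).random(shape)
    return private_mask, shared_mask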

58 changes: 6 additions & 52 deletions openfl/component/aggregator/aggregator.py
@@ -18,7 +18,6 @@
from openfl.protocols import base_pb2, utils
from openfl.protocols.base_pb2 import NamedTensor
from openfl.utilities import TaskResultKey, TensorKey, change_tags
from openfl.utilities.secagg.setup import Setup as secagg_setup

logger = logging.getLogger(__name__)

@@ -209,7 +208,9 @@ def __init__(
self.collaborator_tensor_results = {} # {TensorKey: nparray}}
self._secure_aggregation_enabled = secure_aggregation
if self._secure_aggregation_enabled:
self.secagg = secagg_setup(self.uuid, self.authorized_cols, self.tensor_db)
from openfl.utilities.secagg.bootstrap import SecAggSetup

self.secagg = SecAggSetup(self.uuid, self.authorized_cols, self.tensor_db)

# Callbacks
self.callbacks = callbacks_module.CallbackList(
@@ -754,7 +755,9 @@ def send_local_task_results(
"""
# Check if secure aggregation is enabled.
if self._secure_aggregation_enabled:
secagg_setup = self._secure_aggregation_setup(collaborator_name, named_tensors)
secagg_setup = self.secagg.process_secagg_setup_tensors(
collaborator_name, named_tensors
)
# Task results processing is not required if the tensors belong to
# secure aggregation setup stage.
if secagg_setup:
@@ -903,13 +906,6 @@ def _process_named_tensor(self, named_tensor, collaborator_name):
tuple(named_tensor.tags),
)
tensor_name, origin, round_number, report, tags = tensor_key
# Secure aggregation setup stage key
if "secagg" in tags:
nparray = json.loads(raw_bytes)
self.tensor_db.cache_tensor({tensor_key: nparray})
logger.debug("Created TensorKey: %s", tensor_key)

return tensor_key, nparray

assert "compressed" in tags or "lossy_compressed" in tags, (
f"Named tensor {tensor_key} is not compressed"
@@ -1259,45 +1255,3 @@ def stop(self, failed_collaborator: str = None) -> None:
collaborator_name,
)
self.quit_job_sent_to.append(collaborator_name)

def _secure_aggregation_setup(self, collaborator_name, named_tensors):
"""
Set up secure aggregation for the given collaborator and named tensors.
This method processes named tensors that are part of the secure
aggregation setup stages. It saves the processed tensors to the local
tensor database and checks if all collaborators have sent their data
for the current key. If all collaborators have sent their data, it
proceeds with aggregation for the key.
Args:
collaborator_name (str): The name of the collaborator sending the
tensors.
named_tensors (list): A list of named tensors to be processed.
Returns:
bool: True if the setup is complete or if the tensor does not
belong to secure aggregation setup, otherwise waits for all
collaborators.
"""
secagg_setup = False
for named_tensor in named_tensors:
# Check if the tensor belongs to one from secure aggregation
# setup stages.
if "secagg" not in tuple(named_tensor.tags):
continue
else:
secagg_setup = True
# Process and save tensor to local tensor db.
self._process_named_tensor(named_tensor, collaborator_name)
tensor_name = named_tensor.name
# Check if all collaborators have sent their data for the
# current key.
all_collaborators_sent = self.secagg.check_tensors_received(tensor_name)
if not all_collaborators_sent:
continue
# If all collaborators have sent their data, proceed with
# aggregation for the key.
self.secagg.aggregate_tensor(tensor_name)

return secagg_setup
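
The logic of the removed helper above presumably moves into SecAggSetup.process_secagg_setup_tensors; a condensed sketch of that flow, reconstructed from the deleted code (the signature and helper names shown here are guesses):

def process_secagg_setup_tensors_sketch(secagg, process_named_tensor, collaborator_name, named_tensors) -> bool:
    # Returns True when the payload belongs to the SecAgg setup stage, so the
    # caller can skip normal task-result processing.
    is_setup_stage = False
    for named_tensor in named_tensors:
        if "secagg" not in tuple(named_tensor.tags):
            continue
        is_setup_stage = True
        # Cache the setup tensor in the local tensor DB.
        process_named_tensor(named_tensor, collaborator_name)
        if secagg.check_tensors_received(named_tensor.name):
            # Every collaborator has reported this key; aggregate it now.
            secagg.aggregate_tensor(named_tensor.name)
    return is_setup_stage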
61 changes: 11 additions & 50 deletions openfl/component/collaborator/collaborator.py
@@ -346,8 +346,17 @@ def do_task(self, task, round_number) -> dict:
# with the aggregator.
unmasked_metrics = {}
if self._secure_aggregation_enabled:
unmasked_metrics = self._secure_aggregation_masking(
global_output_tensor_dict, task_name
from openfl.utilities.secagg import calulcate_masked_input_vectors

self._private_mask, self._shared_mask, unmasked_metrics = (
calulcate_masked_input_vectors(
self.collaborator_name,
self.tensor_db,
task_name,
global_output_tensor_dict,
private_mask=self._private_mask,
shared_mask=self._shared_mask,
)
)

# Save global and local output_tensor_dicts to TensorDB
@@ -653,51 +662,3 @@ def named_tensor_to_nparray(self, named_tensor):
self.tensor_db.cache_tensor({decompressed_tensor_key: decompressed_nparray})

return decompressed_nparray

def _secure_aggregation_masking(self, global_output_tensor_dict, task_name):
"""
Apply secure aggregation masking to the global output tensor
dictionary.

This method modifies the provided global output tensor dictionary by
applying secure aggregation masking if secure aggregation is enabled.
It fetches the private and shared masks from the tensor database and
applies them to the tensors in the global output tensor dictionary
that have the "metric" tag.

Args:
global_output_tensor_dict (dict): A dictionary where keys are
tensor keys and values are the corresponding tensors.

Returns:
None: The method modifies the global_output_tensor_dict in place.
"""
import numpy as np

# Storing the masks as class attributes to reduce the number of
# lookups in the database.
# Fetch private mask from tensor db if not already fetched.
if not self._private_mask:
self._private_mask = self.tensor_db.get_tensor_from_cache(
TensorKey("private_mask", self.collaborator_name, -1, False, ("secagg",))
)[0]
# Fetch shared mask from tensor db if not already fetched.
if not self._shared_mask:
self._shared_mask = self.tensor_db.get_tensor_from_cache(
TensorKey("shared_mask", self.collaborator_name, -1, False, ("secagg",))
)[0]

metrics = {}
for tensor_key in global_output_tensor_dict:
tensor_name, _, _, report, tags = tensor_key
if "metric" in tags:
if report:
# Reportable metric must be a scalar
value = float(global_output_tensor_dict[tensor_key])
metrics.update(
{f"{self.collaborator_name}/{task_name}/{tensor_name}/unmasked": value}
)
masked_metric = np.add(self._private_mask, global_output_tensor_dict[tensor_key])
global_output_tensor_dict[tensor_key] = np.add(masked_metric, self._shared_mask)

return metrics
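
The removed method above is essentially what calulcate_masked_input_vectors now consolidates on the collaborator side; a sketch of that consolidated flow, assuming it keeps the in-place masking and unmasked-metric reporting of the deleted code (the real function also takes the tensor DB and handles mask caching, which is omitted here):

import numpy as np

def calculate_masked_input_vectors_sketch(
    collaborator_name, task_name, output_tensors, private_mask, shared_mask
):
    # Record raw scalar metrics for local reporting, then mask every
    # "metric"-tagged tensor with private + shared noise before it is shared
    # with the aggregator.
    unmasked_metrics = {}
    for tensor_key, value in output_tensors.items():
        tensor_name, _, _, report, tags = tensor_key
        if "metric" not in tags:
            continue
        if report:
            # Reportable metric must be a scalar.
            unmasked_metrics[f"{collaborator_name}/{task_name}/{tensor_name}/unmasked"] = float(value)
        output_tensors[tensor_key] = np.add(np.add(value, private_mask), shared_mask)
    return private_mask, shared_mask, unmasked_metrics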
@@ -8,10 +8,6 @@

from openfl.interface.aggregation_functions.weighted_average import WeightedAverage
from openfl.utilities import LocalTensor
from openfl.utilities.secagg import (
calculate_shared_mask,
pseudo_random_generator,
)


class SecureWeightedAverage(WeightedAverage):
@@ -86,6 +82,11 @@ def _generate_masks(self, db_iterator):
- The private masks are stored in a dictionary with the
collaborator's name as the key.
"""
from openfl.utilities.secagg import (
calculate_shared_mask,
pseudo_random_generator,
)

if self._shared_masks and self._private_masks:
return
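
On the aggregator side, the whole scheme rests on a simple identity: summing every collaborator's masked update cancels the pairwise shared masks, and subtracting the regenerated private masks recovers the true aggregate. A sketch of that identity, ignoring the weighting and dropout handling the real SecureWeightedAverage has to deal with:

import numpy as np

def unmask_sum_sketch(masked_updates, private_masks):
    # masked_updates[i] = x_i + private_mask_i + shared_mask_i, and the shared
    # masks sum to zero across the full set of collaborators.
    total = np.sum(masked_updates, axis=0)
    # Private masks are regenerated from the collaborators' seeds, which the
    # aggregator recovers through the secret-sharing setup.
    return total - np.sum(private_masks, axis=0)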

1 change: 1 addition & 0 deletions openfl/utilities/secagg/__init__.py
@@ -5,6 +5,7 @@
from openfl.utilities.secagg.crypto import (
calculate_mask,
calculate_shared_mask,
calulcate_masked_input_vectors,
create_ciphertext,
decipher_ciphertext,
pseudo_random_generator,