Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove meta, enable uint8 quantization, and update server filter behavior #3222

Merged
merged 17 commits into from
Feb 22, 2025
Merged
10 changes: 4 additions & 6 deletions nvflare/app_opt/pt/quantization/dequantizor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,18 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None)
if quantization_type.upper() not in QUANTIZATION_TYPE:
raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}")

source_datatype = dxo.get_meta_prop(key="source_datatype", default=None)
dequantized_params = self.dequantization(
params=dxo.data,
quant_state=dxo.meta["quant_state"],
quantization_type=quantization_type,
source_datatype=dxo.meta["source_datatype"],
source_datatype=source_datatype,
fl_ctx=fl_ctx,
)
# Compose new DXO with dequantized data
dxo.data = dequantized_params
dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM)
dxo.remove_meta_props("quant_state")
dxo.remove_meta_props("source_datatype")
dxo.remove_meta_props([MetaKey.PROCESSED_ALGORITHM, "quant_state", "source_datatype", "quantized_flag"])
dxo.update_shareable(shareable)
self.log_info(fl_ctx, "Dequantized back")
self.log_info(fl_ctx, f"Dequantized back to {source_datatype}")

return dxo
36 changes: 28 additions & 8 deletions nvflare/app_opt/pt/quantization/quantizor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,33 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
"""

self.log_info(fl_ctx, "Running quantization...")
quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
# Compose new DXO with quantized data
# Add quant_state to the new DXO meta
new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
new_dxo.set_meta_prop(key="quant_state", value=quant_state)
new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
self.log_info(fl_ctx, f"Quantized to {self.quantization_type}")

# For an already quantized message, skip quantization.
# The reason, in the current example:
# the server job in this case is 1-N communication with an identical quantization operation;
# the first communication to a client will apply quantization and change the data on the server,
# thus the subsequent communications to the rest of the clients will no longer need to apply quantization.
# This does not apply to the client job, since the client job is 1-1 and quantization applies to each client.
# Potentially:
# If clients talk to each other, it will also be 1-N and the same rule applies.
# If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
# a deep copy of the server data should be made by the filter before applying a different filter.

# quantized_flag is None if it does not exist in meta
quantized_flag = dxo.get_meta_prop("quantized_flag")
if quantized_flag:
self.log_info(fl_ctx, "Already quantized, skip quantization")
new_dxo = dxo
else:
# apply quantization
quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
# Compose new DXO with quantized data
# Add quant_state to the new DXO meta
new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
new_dxo.set_meta_prop(key="quant_state", value=quant_state)
new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
new_dxo.set_meta_prop(key="quantized_flag", value=True)
self.log_info(fl_ctx, f"Quantized from {source_datatype} to {self.quantization_type}")

return new_dxo
Loading