Remove meta, enable uint8 quantization, and update server filter behavior #3222

Merged
merged 17 commits on Feb 22, 2025
14 changes: 8 additions & 6 deletions nvflare/app_opt/pt/quantization/dequantizor.py
@@ -139,6 +139,8 @@ def dequantization(
params[param_name] = params[param_name].astype(np.float64)
elif source_data_type == "float16":
params[param_name] = params[param_name].astype(np.float16)
+ elif source_data_type == "uint8":
+     params[param_name] = params[param_name].astype(np.uint8)
elif source_data_format == "torch":
# convert back to original data type
if source_data_type == "float32":
@@ -149,6 +151,8 @@
params[param_name] = params[param_name].half()
elif source_data_type == "bfloat16":
params[param_name] = params[param_name].bfloat16()
+ elif source_data_type == "uint8":
+     params[param_name] = params[param_name].byte()

n_bytes_after += params[param_name].nbytes

@@ -178,20 +182,18 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None)
if quantization_type.upper() not in QUANTIZATION_TYPE:
raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}")

+ source_datatype = dxo.get_meta_prop(key="source_datatype", default=None)
dequantized_params = self.dequantization(
params=dxo.data,
quant_state=dxo.meta["quant_state"],
quantization_type=quantization_type,
- source_datatype=dxo.meta["source_datatype"],
+ source_datatype=source_datatype,
fl_ctx=fl_ctx,
)
# Compose new DXO with dequantized data
dxo.data = dequantized_params
- dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM)
- dxo.remove_meta_props("quant_state")
- dxo.remove_meta_props("source_datatype")
+ dxo.remove_meta_props([MetaKey.PROCESSED_ALGORITHM, "quant_state", "source_datatype", "quantized"])
dxo.update_shareable(shareable)
- self.log_info(fl_ctx, "Dequantized back")
+ self.log_info(fl_ctx, f"Dequantized back to {source_datatype}")

return dxo
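
Taken together, the branches above implement the dtype-restore step that this PR extends with uint8: the quantizor records "source_datatype" as a meta prop, and the dequantizor casts the restored values back to it. A minimal standalone sketch of that step, with an illustrative helper name that is not part of the NVFlare API:

    import numpy as np
    import torch

    def restore_source_dtype(value, source_data_format: str, source_data_type: str):
        # Cast dequantized values back to the dtype recorded by the quantizor
        # in the "source_datatype" meta prop, mirroring the branches above.
        if source_data_format == "numpy":
            np_types = {"float64": np.float64, "float16": np.float16, "uint8": np.uint8}
            return value.astype(np_types[source_data_type])
        if source_data_format == "torch":
            torch_casts = {"float16": value.half, "bfloat16": value.bfloat16, "uint8": value.byte}
            return torch_casts[source_data_type]()
        return value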
44 changes: 36 additions & 8 deletions nvflare/app_opt/pt/quantization/quantizor.py
@@ -21,6 +21,7 @@

from nvflare.apis.dxo import DXO, DataKind, MetaKey
from nvflare.apis.dxo_filter import DXOFilter
+ from nvflare.apis.fl_constant import ProcessType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE
@@ -142,9 +143,15 @@ def quantization(self, params: dict, fl_ctx: FLContext):
n_bytes_meta += quant_state[param_name]["code"].nbytes
else:
if source_data_format == "numpy":
+ if source_data_bits == 8:
+     # int8 cannot be directly quantized to 4bit, need to convert to float first
+     values = values.astype(np.float32)
# if numpy, first convert numpy array to tensor, need to use GPU
values_tensor = torch.as_tensor(values).cuda()
elif source_data_format == "torch":
+ if source_data_bits == 8:
+     # int8 cannot be directly quantized to 4bit, need to convert to float first
+     values = values.float()
# if torch, directly use the tensor, need to use GPU
values_tensor = values.cuda()
# then quantize the tensor
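
To see the cast-before-quantize rule in isolation: a minimal sketch, assuming the bitsandbytes quantize_4bit kernel that the 4-bit path here relies on (the tensor contents are illustrative):

    import torch
    from bitsandbytes.functional import quantize_4bit

    # 4-bit kernels accept only floating-point input, so 8-bit integer
    # weights are promoted to float32 first; the kernels also run on CUDA.
    values = torch.randint(-128, 128, (1024,), dtype=torch.int8)
    values_tensor = values.float().cuda()
    quantized, quant_state = quantize_4bit(values_tensor, quant_type="fp4")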
@@ -195,13 +202,34 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
"""

self.log_info(fl_ctx, "Running quantization...")
- quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
- # Compose new DXO with quantized data
- # Add quant_state to the new DXO meta
- new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
- new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
- new_dxo.set_meta_prop(key="quant_state", value=quant_state)
- new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
- self.log_info(fl_ctx, f"Quantized to {self.quantization_type}")

+ # for a server job with an already-quantized message, skip quantization
+ # The reason is:
+ # a server job in this case is 1-N communication with an identical quantization operation
+ # the first communication to a client will apply quantization and change the data on the server
+ # thus the subsequent communications to the rest of the clients no longer need to apply quantization
+ # This is not needed for a client job, since a client job is 1-1 and quantization applies to each client
+ # The behavior will also differ if each server-client filter is different, in which case
+ # the filter should make a deep copy of the server data before applying the process
+ filter_flag = True
+ process_type = fl_ctx.get_process_type()
+ quantized_flag = dxo.get_meta_prop("quantized")
+ if process_type == ProcessType.SERVER_JOB and quantized_flag:
+     filter_flag = False
+
+ if filter_flag:
+     # apply quantization
+     quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
+     # Compose new DXO with quantized data
+     # Add quant_state to the new DXO meta
+     new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
+     new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
+     new_dxo.set_meta_prop(key="quant_state", value=quant_state)
+     new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
+     new_dxo.set_meta_prop(key="quantized", value=True)
+     self.log_info(fl_ctx, f"Quantized from {source_datatype} to {self.quantization_type}")
+ else:
+     self.log_info(fl_ctx, "Skipping quantization, already quantized")
+     new_dxo = dxo

return new_dxo
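
The skip condition can be read as a small predicate. A hedged sketch using only the calls that appear in this diff; the helper name is illustrative, not NVFlare API:

    from nvflare.apis.fl_constant import ProcessType

    def should_apply_quantization(dxo, fl_ctx) -> bool:
        # A server job broadcasts one message to N clients and the filter runs
        # once per client; the first pass tags the DXO with "quantized", so
        # later passes can short-circuit instead of re-quantizing.
        already_quantized = dxo.get_meta_prop("quantized")
        if fl_ctx.get_process_type() == ProcessType.SERVER_JOB and already_quantized:
            return False
        return True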