Update quantization to force gpu usage for blockwise8 #3256

Merged (10 commits) on Feb 27, 2025
16 changes: 8 additions & 8 deletions examples/advanced/llm_hf/sft_job.py
@@ -19,8 +19,8 @@
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
from nvflare.job_config.script_runner import ScriptRunner


@@ -67,10 +67,10 @@ def main():

if args.quantize_mode:
# If using quantization, add quantize filters.
quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
dequantizor = ModelDequantizor()
job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
dequantizer = ModelDequantizer()
job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

# Define the model persistor and send to server
# First send the model to the server
@@ -106,8 +106,8 @@ def main():
job.to(runner, site_name, tasks=["train"])

if args.quantize_mode:
job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)

# Export the job
print("job_dir=", job_dir)
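Pulling the renamed pieces of the job script together, here is a condensed sketch of the filter wiring after this change. The FedJob and FilterType import paths are assumptions (they sit outside the hunks shown above), and the job name and site name are placeholders.

from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
from nvflare.job_config.api import FedJob  # assumed import path, not shown in this diff
from nvflare.job_config.defs import FilterType  # assumed import path, not shown in this diff

job = FedJob(name="llm_hf_sft")  # placeholder job name
quantize_mode = "blockwise8"  # any value from QUANTIZATION_TYPE

# Server side: quantize outgoing task data, dequantize incoming task results.
quantizer = ModelQuantizer(quantization_type=quantize_mode)
dequantizer = ModelDequantizer()
job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

# Client side: the mirror image of the server filters.
site_name = "site-1"  # placeholder site name
job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)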
11 changes: 8 additions & 3 deletions nvflare/app_opt/pt/quantization/constant.py
@@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Supported Input Data Type
# Message quantization is mainly for reducing the size of messages that can be
# significantly large, e.g. for LLMs. Thus, the supported input data types
# are the common ones used during LLM training, including fp32, fp16, and bf16.
DATA_TYPE = [
"FLOAT64",
"FLOAT32",
"FLOAT16",
"BFLOAT16",
"UINT8",
"INT8",
]

# Supported Quantization Type to reduce the above input data types
# The quantization types are mainly for reducing the model size;
# hence, we support 16-, 8-, and 4-bit quantization.
# Note that 8- and 4-bit quantization requires GPU support.
QUANTIZATION_TYPE = [
"FLOAT16",
"BLOCKWISE8",
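The new comments call out that 8- and 4-bit quantization relies on GPU support via bitsandbytes. Below is a minimal sketch, assuming nothing beyond torch, of how a job script could fail fast when a GPU-only mode is requested on a CPU-only machine; the helper name is hypothetical and not part of this change.

import torch

# GPU-backed modes from QUANTIZATION_TYPE; "float16" remains usable on CPU.
GPU_ONLY_MODES = {"blockwise8", "float4", "normfloat4"}

def check_quantize_mode(quantize_mode: str) -> str:
    """Hypothetical guard: reject GPU-only quantization modes without CUDA."""
    if quantize_mode.lower() in GPU_ONLY_MODES and not torch.cuda.is_available():
        raise RuntimeError(
            f"Quantization type '{quantize_mode}' needs a GPU; use 'float16' on CPU-only hosts."
        )
    return quantize_mode

quantize_mode = check_quantize_mode("blockwise8")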
nvflare/app_opt/pt/quantization/dequantizer.py (renamed from dequantizor.py)
@@ -26,7 +26,7 @@
from nvflare.app_opt.pt.quantization.constant import QUANTIZATION_TYPE


class ModelDequantizor(DXOFilter):
class ModelDequantizer(DXOFilter):
def __init__(self):
"""Filter to dequantize Shareable object to recover from quantization

@@ -84,17 +84,18 @@ def dequantization(
params[param_name] = values
elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
# use bitsandbytes to dequantize the values
# need GPU for general support
# extract quantization state
if quantization_type == "blockwise8":
if source_data_format == "numpy":
# first convert numpy array to tensor if numpy
quantized = torch.as_tensor(values)
absmax = torch.as_tensor(quant_state[param_name]["absmax"])
code = torch.as_tensor(quant_state[param_name]["code"])
quantized = torch.as_tensor(values).cuda()
absmax = torch.as_tensor(quant_state[param_name]["absmax"]).cuda()
code = torch.as_tensor(quant_state[param_name]["code"]).cuda()
elif source_data_format == "torch":
quantized = values
absmax = quant_state[param_name]["absmax"]
code = quant_state[param_name]["code"]
quantized = values.cuda()
absmax = quant_state[param_name]["absmax"].cuda()
code = quant_state[param_name]["code"].cuda()
# de-quantize
dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
else:
Expand Down Expand Up @@ -125,6 +126,7 @@ def dequantization(
dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
else:
dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")

if source_data_format == "numpy":
params[param_name] = dequantized.cpu().numpy()
elif source_data_format == "torch":
@@ -135,16 +137,12 @@
# convert back to original data type
if source_data_type == "float32":
params[param_name] = params[param_name].astype(np.float32)
elif source_data_type == "float64":
params[param_name] = params[param_name].astype(np.float64)
elif source_data_type == "float16":
params[param_name] = params[param_name].astype(np.float16)
elif source_data_format == "torch":
# convert back to original data type
if source_data_type == "float32":
params[param_name] = params[param_name].float()
elif source_data_type == "float64":
params[param_name] = params[param_name].double()
elif source_data_type == "float16":
params[param_name] = params[param_name].half()
elif source_data_type == "bfloat16":
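To see the effect of the dequantizer change in isolation, here is a self-contained round trip with bitsandbytes, assuming a CUDA device is available: the received numpy payload and its absmax/code state are moved to the GPU before dequantize_blockwise, then brought back to the CPU and cast to the source dtype, mirroring the blockwise8/numpy branch above. The tensor shape and variable names are illustrative only.

import numpy as np
import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

# Simulate a blockwise8 payload as it would arrive in numpy format.
weights = torch.randn(4096, dtype=torch.float32)
quantized_gpu, state = quantize_blockwise(weights.cuda())
values = quantized_gpu.cpu().numpy()
param_state = {"absmax": state.absmax.cpu().numpy(), "code": state.code.cpu().numpy()}

# Dequantize on the GPU, as the filter now forces for blockwise8.
quantized = torch.as_tensor(values).cuda()
absmax = torch.as_tensor(param_state["absmax"]).cuda()
code = torch.as_tensor(param_state["code"]).cuda()
dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)

# Return to the source data format and dtype (numpy float32 here).
restored = dequantized.cpu().numpy().astype(np.float32)
print("max abs error:", np.abs(restored - weights.numpy()).max())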
nvflare/app_opt/pt/quantization/quantizer.py (renamed from quantizor.py)
@@ -26,7 +26,7 @@
from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE


class ModelQuantizor(DXOFilter):
class ModelQuantizer(DXOFilter):
def __init__(
self,
quantization_type="float16",
@@ -120,41 +120,39 @@ def quantization(self, params: dict, fl_ctx: FLContext):
elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]:
# use bitsandbytes to quantize the values
# input is a tensor, output is a tuple of (quantized tensor, quantized_state)
if self.quantization_type == "blockwise8":
if source_data_format == "numpy":
# if numpy, first convert numpy array to tensor
values_tensor = torch.as_tensor(values)
elif source_data_format == "torch":
values_tensor = values

# then quantize the tensor
# CPU has limited support for 8- and 4-bit quantization,
# so for general-purpose support we use the GPU here
if source_data_format == "numpy":
# if numpy, first convert numpy array to tensor, need to use GPU
values_tensor = torch.as_tensor(values).cuda()
elif source_data_format == "torch":
# if torch, directly use the tensor, need to use GPU
values_tensor = values.cuda()

if self.quantization_type == "blockwise8":
# quantize the tensor
quantized, quantized_state = quantize_blockwise(values_tensor)
# add the quantization state and values, keep source data format
if source_data_format == "numpy":
quant_state[param_name]["absmax"] = quantized_state.absmax.numpy()
quant_state[param_name]["code"] = quantized_state.code.numpy()
values = quantized.numpy()
quant_state[param_name]["absmax"] = quantized_state.absmax.cpu().numpy()
quant_state[param_name]["code"] = quantized_state.code.cpu().numpy()
values = quantized.cpu().numpy()
elif source_data_format == "torch":
quant_state[param_name]["absmax"] = quantized_state.absmax
quant_state[param_name]["code"] = quantized_state.code
values = quantized
quant_state[param_name]["absmax"] = quantized_state.absmax.cpu()
quant_state[param_name]["code"] = quantized_state.code.cpu()
values = quantized.cpu()
n_bytes_meta += quant_state[param_name]["absmax"].nbytes
n_bytes_meta += quant_state[param_name]["code"].nbytes
else:
if source_data_format == "numpy":
# if numpy, first convert numpy array to tensor, need to use GPU
values_tensor = torch.as_tensor(values).cuda()
elif source_data_format == "torch":
# if torch, directly use the tensor, need to use GPU
values_tensor = values.cuda()
# then quantize the tensor
if self.quantization_type == "float4":
quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4")
else:
quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4")
# add the quantization state and values, keep source data format
quantized_state = quantized_state.as_dict()

# prepare the message
for state_name, state in quantized_state.items():
if isinstance(state, torch.Tensor):
if source_data_format == "numpy":
@@ -171,6 +169,7 @@ def quantization(self, params: dict, fl_ctx: FLContext):
values = quantized.cpu().numpy()
elif source_data_format == "torch":
values = quantized.cpu()

params[param_name] = values
n_bytes_after += params[param_name].nbytes

@@ -203,8 +202,8 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
# thus the subsequent communications to the rest of clients will no longer need to apply quantization
# This will not apply to client job, since the client job will be 1-1 and quantization applies to each client
# Potentially:
# If clients talks to each other, it will also be 1-N and same rule applies
# If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
# - If clients talk to each other, it will also be 1-N and the same rule applies
# - If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
# a deep copy of the server data should be made by filter before applying a different filter

# quantized_flag None if does not exist in meta
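On the sending side, the same change means the quantizer now does its blockwise8 work on the GPU and then detaches everything back to the CPU before it goes into the message. Below is a standalone sketch, assuming a CUDA device and following the torch-format branch above; the helper function and its byte bookkeeping are illustrative, not the filter's API.

import torch
from bitsandbytes.functional import quantize_blockwise

def quantize_blockwise8(values: torch.Tensor):
    """Illustrative helper: blockwise8-quantize one parameter tensor on the GPU."""
    values_tensor = values.cuda()  # 8-bit quantization needs GPU support
    quantized, quantized_state = quantize_blockwise(values_tensor)

    # Move the payload and its quantization state back to the CPU for transmission.
    payload = quantized.cpu()
    meta = {"absmax": quantized_state.absmax.cpu(), "code": quantized_state.code.cpu()}

    n_bytes_before = values.nbytes
    n_bytes_after = payload.nbytes + meta["absmax"].nbytes + meta["code"].nbytes
    return payload, meta, n_bytes_before, n_bytes_after

param = torch.randn(1024, 1024)  # fp32 example parameter
payload, meta, before, after = quantize_blockwise8(param)
print(f"{before} bytes -> {after} bytes after blockwise8 quantization")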