From 90448b734b7a79711d194f393ae23f5d069d8def Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 3 Mar 2025 10:43:30 -0500 Subject: [PATCH] add error message for unsupported data type by numpy --- nvflare/app_opt/pt/numpy_params_converter.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nvflare/app_opt/pt/numpy_params_converter.py b/nvflare/app_opt/pt/numpy_params_converter.py index 503da8e284..10bdda7f9c 100644 --- a/nvflare/app_opt/pt/numpy_params_converter.py +++ b/nvflare/app_opt/pt/numpy_params_converter.py @@ -48,7 +48,15 @@ def convert(self, params: Dict, fl_ctx) -> Dict: exclude_vars = {} for k, v in params.items(): if isinstance(v, torch.Tensor): - return_tensors[k] = v.cpu().numpy() + # Check type of tensor and convert to numpy + data_type = str(v.dtype).split(".")[1] + # Numpy does not support bfloat16, give error + if data_type == "bfloat16": + raise ValueError( + f"Unsupported data type for numpy transmission: {data_type}, please use pytorch exchange format or convert params to a supported data type (fp32, fp16, etc.)" + ) + else: + return_tensors[k] = v.cpu().numpy() tensor_shapes[k] = v.shape else: exclude_vars[k] = v