diff --git a/nvflare/app_opt/pt/numpy_params_converter.py b/nvflare/app_opt/pt/numpy_params_converter.py index 503da8e284..10bdda7f9c 100644 --- a/nvflare/app_opt/pt/numpy_params_converter.py +++ b/nvflare/app_opt/pt/numpy_params_converter.py @@ -48,7 +48,15 @@ def convert(self, params: Dict, fl_ctx) -> Dict: exclude_vars = {} for k, v in params.items(): if isinstance(v, torch.Tensor): - return_tensors[k] = v.cpu().numpy() + # Check type of tensor and convert to numpy + data_type = str(v.dtype).split(".")[1] + # Numpy does not support bfloat16, give error + if data_type == "bfloat16": + raise ValueError( + f"Unsupported data type for numpy transmission: {data_type}, please use pytorch exchange format or convert params to a supported data type (fp32, fp16, etc.)" + ) + else: + return_tensors[k] = v.cpu().numpy() tensor_shapes[k] = v.shape else: exclude_vars[k] = v