
Commit fc33bbf

LamForest authored and pytorchmergebot committed

better support set_default_dtype(torch.float16), update doc (pytorch#121730)

1. Fixes pytorch#121300
2. Previously, calling `torch.tensor([2j])` after `torch.set_default_dtype(torch.float16)` would cause a runtime error. This PR also fixes that and enables the corresponding test.

Pull Request resolved: pytorch#121730
Approved by: https://github.com/peterbell10
1 parent 8fdd812 commit fc33bbf
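
As a quick sanity check of the behavior this commit describes (a minimal sketch, assuming a PyTorch build that already includes this change; on older builds the complex construction raised instead):

    import torch

    # With a float16 default, Python complex scalars now infer to complex32 (chalf).
    torch.set_default_dtype(torch.float16)
    t = torch.tensor([2j])   # raised "invalid default scalar type for complex" before this commit
    print(t.dtype)           # expected: torch.complex32

    # Restore the usual default so later code is unaffected.
    torch.set_default_dtype(torch.float32)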

File tree

4 files changed: +25 -9 lines changed

test/test_complex.py (+7 -2)

@@ -18,12 +18,17 @@ def test_to_list(self, device, dtype):
         # there's no garbage value in the resultant list
         self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
 
-    @dtypes(torch.float32, torch.float64)
+    @dtypes(torch.float32, torch.float64, torch.float16)
     def test_dtype_inference(self, device, dtype):
         # issue: https://github.com/pytorch/pytorch/issues/36834
         with set_default_dtype(dtype):
             x = torch.tensor([3., 3. + 5.j], device=device)
-        self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
+        if dtype == torch.float16:
+            self.assertEqual(x.dtype, torch.chalf)
+        elif dtype == torch.float32:
+            self.assertEqual(x.dtype, torch.cfloat)
+        else:
+            self.assertEqual(x.dtype, torch.cdouble)
 
     @dtypes(*complex_types())
     def test_conj_copy(self, device, dtype):
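
The same mapping the updated test asserts can be checked standalone; this sketch is illustrative only and is not part of the test suite (the test itself relies on the @dtypes decorator and a set_default_dtype context manager from the test utilities):

    import torch

    # Expected default complex dtype for each supported default floating point dtype.
    expected_complex = {
        torch.float16: torch.chalf,    # complex32
        torch.float32: torch.cfloat,   # complex64
        torch.float64: torch.cdouble,  # complex128
    }

    for float_dtype, complex_dtype in expected_complex.items():
        torch.set_default_dtype(float_dtype)
        x = torch.tensor([3., 3. + 5.j])
        assert x.dtype == complex_dtype, (float_dtype, x.dtype)

    torch.set_default_dtype(torch.float32)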

torch/__init__.py (+14 -7)

@@ -750,17 +750,17 @@ def set_default_tensor_type(t):
 def set_default_dtype(d):
     r"""
 
-    Sets the default floating point dtype to :attr:`d`. Supports torch.float32
-    and torch.float64 as inputs. Other dtypes may be accepted without complaint
-    but are not supported and are unlikely to work as expected.
+    Sets the default floating point dtype to :attr:`d`. Supports floating point dtypes
+    as inputs. Other dtypes will cause torch to raise an exception.
 
     When PyTorch is initialized its default floating point dtype is torch.float32,
     and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
     type inference. The default floating point dtype is used to:
 
-    1. Implicitly determine the default complex dtype. When the default floating point
-       type is float32 the default complex dtype is complex64, and when the default
-       floating point type is float64 the default complex type is complex128.
+    1. Implicitly determine the default complex dtype. When the default floating point
+       type is float16, the default complex dtype is complex32. For float32, it is complex64;
+       for float64, it is complex128. For bfloat16, an exception is raised because
+       there is no corresponding complex type for bfloat16.
     2. Infer the dtype for tensors constructed using Python floats or complex Python
        numbers. See examples below.
     3. Determine the result of type promotion between bool and integer tensors and
@@ -782,14 +782,21 @@ def set_default_dtype(d):
         torch.complex64
 
         >>> torch.set_default_dtype(torch.float64)
-
         >>> # Python floats are now interpreted as float64
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
         torch.float64
         >>> # Complex Python numbers are now interpreted as complex128
         >>> torch.tensor([1.2, 3j]).dtype   # a new complex tensor
         torch.complex128
 
+        >>> torch.set_default_dtype(torch.float16)
+        >>> # Python floats are now interpreted as float16
+        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
+        torch.float16
+        >>> # Complex Python numbers are now interpreted as complex32
+        >>> torch.tensor([1.2, 3j]).dtype   # a new complex tensor
+        torch.complex32
+
     """
     _C._set_default_dtype(d)
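
The updated docstring also notes that bfloat16 has no complex counterpart. A hedged sketch of what that implies for inference (the exact exception text is an assumption; only the fact that construction fails is documented):

    import torch

    torch.set_default_dtype(torch.bfloat16)
    try:
        torch.tensor([1.2, 3j])   # no complex dtype corresponds to bfloat16
    except RuntimeError as err:
        print("raised as documented:", err)
    finally:
        torch.set_default_dtype(torch.float32)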

torch/_refs/__init__.py (+2)

@@ -6296,6 +6296,8 @@ def _infer_scalar_type(obj):
             return torch.cfloat
         elif default_dtype is torch.double:
             return torch.cdouble
+        elif default_dtype is torch.half:
+            return torch.chalf
         else:
             raise RuntimeError("invalid default scalar type for complex")
     if isinstance(obj, torch.Tensor):
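
The Python reference path above and the C++ path below implement the same default-float-to-complex mapping; the helper name in this sketch is hypothetical and only mirrors the branch shown in the diff:

    import torch

    def complex_dtype_for_default(default_dtype):
        # Mirrors the complex branch of _infer_scalar_type after this change.
        if default_dtype is torch.float:
            return torch.cfloat
        elif default_dtype is torch.double:
            return torch.cdouble
        elif default_dtype is torch.half:
            return torch.chalf
        else:
            raise RuntimeError("invalid default scalar type for complex")

    print(complex_dtype_for_default(torch.half))   # torch.complex32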

torch/csrc/utils/tensor_new.cpp (+2)

@@ -154,6 +154,8 @@ ScalarType infer_scalar_type(PyObject* obj) {
         return ScalarType::ComplexFloat;
       case ScalarType::Double:
         return ScalarType::ComplexDouble;
+      case ScalarType::Half:
+        return ScalarType::ComplexHalf;
       default:
         TORCH_CHECK(false, "invalid default scalar type for complex");
     }
