@@ -750,17 +750,17 @@ def set_default_tensor_type(t):
750
750
def set_default_dtype (d ):
751
751
r"""
752
752
753
- Sets the default floating point dtype to :attr:`d`. Supports torch.float32
754
- and torch.float64 as inputs. Other dtypes may be accepted without complaint
755
- but are not supported and are unlikely to work as expected.
753
+ Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
754
+ as inputs. Other dtypes will cause torch to raise an exception.
756
755
757
756
When PyTorch is initialized its default floating point dtype is torch.float32,
758
757
and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
759
758
type inference. The default floating point dtype is used to:
760
759
761
- 1. Implicitly determine the default complex dtype. When the default floating point
762
- type is float32 the default complex dtype is complex64, and when the default
763
- floating point type is float64 the default complex type is complex128.
760
+ 1. Implicitly determine the default complex dtype. When the default floating type is float16,
761
+ the default complex dtype is complex32. For float32, the default complex dtype is complex64.
762
+ For float64, it is complex128. For bfloat16, an exception will be raised because
763
+ there is no corresponding complex type for bfloat16.
764
764
2. Infer the dtype for tensors constructed using Python floats or complex Python
765
765
numbers. See examples below.
766
766
3. Determine the result of type promotion between bool and integer tensors and
@@ -782,14 +782,21 @@ def set_default_dtype(d):
782
782
torch.complex64
783
783
784
784
>>> torch.set_default_dtype(torch.float64)
785
-
786
785
>>> # Python floats are now interpreted as float64
787
786
>>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
788
787
torch.float64
789
788
>>> # Complex Python numbers are now interpreted as complex128
790
789
>>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
791
790
torch.complex128
792
791
792
+ >>> torch.set_default_dtype(torch.float16)
793
+ >>> # Python floats are now interpreted as float16
794
+ >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
795
+ torch.float16
796
+ >>> # Complex Python numbers are now interpreted as complex128
797
+ >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
798
+ torch.complex32
799
+
793
800
"""
794
801
_C ._set_default_dtype (d )
795
802
0 commit comments