@@ -86,84 +86,84 @@ using acc_type = typename AccumulateType<T, is_cuda>::type;
86
86
// Shorthand for ACC_TYPE with the device argument fixed to
// c10::DeviceType::CUDA.
// NOTE: the parameter list must directly follow the macro name. With a space
// in between, the preprocessor defines an object-like macro and `(t, acc_t)`
// becomes part of the replacement text instead of the parameter list.
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
87
87
// Shorthand for ACC_TYPE with the device argument fixed to
// c10::DeviceType::CPU.
// NOTE: the parameter list must directly follow the macro name. With a space
// in between, the preprocessor defines an object-like macro and `(t, acc_t)`
// becomes part of the replacement text instead of the parameter list.
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
88
88
89
- MPS_ACC_TYPE (BFloat16, float );
90
- MPS_ACC_TYPE (Half, float );
91
- MPS_ACC_TYPE (Float8_e5m2, float );
92
- MPS_ACC_TYPE (Float8_e4m3fn, float );
93
- MPS_ACC_TYPE (Float8_e5m2fnuz, float );
94
- MPS_ACC_TYPE (Float8_e4m3fnuz, float );
95
- MPS_ACC_TYPE (float , float );
96
- MPS_ACC_TYPE (double , float );
97
- MPS_ACC_TYPE (int8_t , int64_t );
98
- MPS_ACC_TYPE (uint8_t , int64_t );
99
- MPS_ACC_TYPE (char , int64_t );
100
- MPS_ACC_TYPE (int16_t , int64_t );
101
- MPS_ACC_TYPE (int32_t , int64_t );
102
- MPS_ACC_TYPE (int64_t , int64_t );
103
- MPS_ACC_TYPE (bool , bool );
104
- MPS_ACC_TYPE (c10::complex<Half>, c10::complex<float >);
105
- MPS_ACC_TYPE (c10::complex<float >, c10::complex<float >);
106
- MPS_ACC_TYPE (c10::complex<double >, c10::complex<float >);
107
-
108
- XPU_ACC_TYPE (BFloat16, float );
109
- XPU_ACC_TYPE (Half, float );
110
- XPU_ACC_TYPE (Float8_e5m2, float );
111
- XPU_ACC_TYPE (Float8_e4m3fn, float );
112
- XPU_ACC_TYPE (Float8_e5m2fnuz, float );
113
- XPU_ACC_TYPE (Float8_e4m3fnuz, float );
114
- XPU_ACC_TYPE (float , float );
115
- XPU_ACC_TYPE (double , double );
116
- XPU_ACC_TYPE (int8_t , int64_t );
117
- XPU_ACC_TYPE (uint8_t , int64_t );
118
- XPU_ACC_TYPE (char , int64_t );
119
- XPU_ACC_TYPE (int16_t , int64_t );
120
- XPU_ACC_TYPE (int32_t , int64_t );
121
- XPU_ACC_TYPE (int64_t , int64_t );
122
- XPU_ACC_TYPE (bool , bool );
123
- XPU_ACC_TYPE (c10::complex<Half>, c10::complex<float >);
124
- XPU_ACC_TYPE (c10::complex<float >, c10::complex<float >);
125
- XPU_ACC_TYPE (c10::complex<double >, c10::complex<double >);
89
+ MPS_ACC_TYPE (BFloat16, float )
90
+ MPS_ACC_TYPE (Half, float )
91
+ MPS_ACC_TYPE (Float8_e5m2, float )
92
+ MPS_ACC_TYPE (Float8_e4m3fn, float )
93
+ MPS_ACC_TYPE (Float8_e5m2fnuz, float )
94
+ MPS_ACC_TYPE (Float8_e4m3fnuz, float )
95
+ MPS_ACC_TYPE (float , float )
96
+ MPS_ACC_TYPE (double , float )
97
+ MPS_ACC_TYPE (int8_t , int64_t )
98
+ MPS_ACC_TYPE (uint8_t , int64_t )
99
+ MPS_ACC_TYPE (char , int64_t )
100
+ MPS_ACC_TYPE (int16_t , int64_t )
101
+ MPS_ACC_TYPE (int32_t , int64_t )
102
+ MPS_ACC_TYPE (int64_t , int64_t )
103
+ MPS_ACC_TYPE (bool , bool )
104
+ MPS_ACC_TYPE (c10::complex<Half>, c10::complex<float >)
105
+ MPS_ACC_TYPE (c10::complex<float >, c10::complex<float >)
106
+ MPS_ACC_TYPE (c10::complex<double >, c10::complex<float >)
107
+
108
+ XPU_ACC_TYPE (BFloat16, float )
109
+ XPU_ACC_TYPE (Half, float )
110
+ XPU_ACC_TYPE (Float8_e5m2, float )
111
+ XPU_ACC_TYPE (Float8_e4m3fn, float )
112
+ XPU_ACC_TYPE (Float8_e5m2fnuz, float )
113
+ XPU_ACC_TYPE (Float8_e4m3fnuz, float )
114
+ XPU_ACC_TYPE (float , float )
115
+ XPU_ACC_TYPE (double , double )
116
+ XPU_ACC_TYPE (int8_t , int64_t )
117
+ XPU_ACC_TYPE (uint8_t , int64_t )
118
+ XPU_ACC_TYPE (char , int64_t )
119
+ XPU_ACC_TYPE (int16_t , int64_t )
120
+ XPU_ACC_TYPE (int32_t , int64_t )
121
+ XPU_ACC_TYPE (int64_t , int64_t )
122
+ XPU_ACC_TYPE (bool , bool )
123
+ XPU_ACC_TYPE (c10::complex<Half>, c10::complex<float >)
124
+ XPU_ACC_TYPE (c10::complex<float >, c10::complex<float >)
125
+ XPU_ACC_TYPE (c10::complex<double >, c10::complex<double >)
126
126
127
127
#if defined(__CUDACC__) || defined(__HIPCC__)
128
- CUDA_ACC_TYPE (half, float );
128
+ CUDA_ACC_TYPE (half, float )
129
129
#endif
130
- CUDA_ACC_TYPE (BFloat16, float );
131
- CUDA_ACC_TYPE (Half, float );
132
- CUDA_ACC_TYPE (Float8_e5m2, float );
133
- CUDA_ACC_TYPE (Float8_e4m3fn, float );
134
- CUDA_ACC_TYPE (Float8_e5m2fnuz, float );
135
- CUDA_ACC_TYPE (Float8_e4m3fnuz, float );
136
- CUDA_ACC_TYPE (float , float );
137
- CUDA_ACC_TYPE (double , double );
138
- CUDA_ACC_TYPE (int8_t , int64_t );
139
- CUDA_ACC_TYPE (uint8_t , int64_t );
140
- CUDA_ACC_TYPE (char , int64_t );
141
- CUDA_ACC_TYPE (int16_t , int64_t );
142
- CUDA_ACC_TYPE (int32_t , int64_t );
143
- CUDA_ACC_TYPE (int64_t , int64_t );
144
- CUDA_ACC_TYPE (bool , bool );
145
- CUDA_ACC_TYPE (c10::complex<Half>, c10::complex<float >);
146
- CUDA_ACC_TYPE (c10::complex<float >, c10::complex<float >);
147
- CUDA_ACC_TYPE (c10::complex<double >, c10::complex<double >);
148
-
149
- CPU_ACC_TYPE (BFloat16, float );
150
- CPU_ACC_TYPE (Half, float );
151
- CPU_ACC_TYPE (Float8_e5m2, float );
152
- CPU_ACC_TYPE (Float8_e4m3fn, float );
153
- CPU_ACC_TYPE (Float8_e5m2fnuz, float );
154
- CPU_ACC_TYPE (Float8_e4m3fnuz, float );
155
- CPU_ACC_TYPE (float , double );
156
- CPU_ACC_TYPE (double , double );
157
- CPU_ACC_TYPE (int8_t , int64_t );
158
- CPU_ACC_TYPE (uint8_t , int64_t );
159
- CPU_ACC_TYPE (char , int64_t );
160
- CPU_ACC_TYPE (int16_t , int64_t );
161
- CPU_ACC_TYPE (int32_t , int64_t );
162
- CPU_ACC_TYPE (int64_t , int64_t );
163
- CPU_ACC_TYPE (bool , bool );
164
- CPU_ACC_TYPE (c10::complex<Half>, c10::complex<float >);
165
- CPU_ACC_TYPE (c10::complex<float >, c10::complex<double >);
166
- CPU_ACC_TYPE (c10::complex<double >, c10::complex<double >);
130
+ CUDA_ACC_TYPE (BFloat16, float )
131
+ CUDA_ACC_TYPE (Half, float )
132
+ CUDA_ACC_TYPE (Float8_e5m2, float )
133
+ CUDA_ACC_TYPE (Float8_e4m3fn, float )
134
+ CUDA_ACC_TYPE (Float8_e5m2fnuz, float )
135
+ CUDA_ACC_TYPE (Float8_e4m3fnuz, float )
136
+ CUDA_ACC_TYPE (float , float )
137
+ CUDA_ACC_TYPE (double , double )
138
+ CUDA_ACC_TYPE (int8_t , int64_t )
139
+ CUDA_ACC_TYPE (uint8_t , int64_t )
140
+ CUDA_ACC_TYPE (char , int64_t )
141
+ CUDA_ACC_TYPE (int16_t , int64_t )
142
+ CUDA_ACC_TYPE (int32_t , int64_t )
143
+ CUDA_ACC_TYPE (int64_t , int64_t )
144
+ CUDA_ACC_TYPE (bool , bool )
145
+ CUDA_ACC_TYPE (c10::complex<Half>, c10::complex<float >)
146
+ CUDA_ACC_TYPE (c10::complex<float >, c10::complex<float >)
147
+ CUDA_ACC_TYPE (c10::complex<double >, c10::complex<double >)
148
+
149
+ CPU_ACC_TYPE (BFloat16, float )
150
+ CPU_ACC_TYPE (Half, float )
151
+ CPU_ACC_TYPE (Float8_e5m2, float )
152
+ CPU_ACC_TYPE (Float8_e4m3fn, float )
153
+ CPU_ACC_TYPE (Float8_e5m2fnuz, float )
154
+ CPU_ACC_TYPE (Float8_e4m3fnuz, float )
155
+ CPU_ACC_TYPE (float , double )
156
+ CPU_ACC_TYPE (double , double )
157
+ CPU_ACC_TYPE (int8_t , int64_t )
158
+ CPU_ACC_TYPE (uint8_t , int64_t )
159
+ CPU_ACC_TYPE (char , int64_t )
160
+ CPU_ACC_TYPE (int16_t , int64_t )
161
+ CPU_ACC_TYPE (int32_t , int64_t )
162
+ CPU_ACC_TYPE (int64_t , int64_t )
163
+ CPU_ACC_TYPE (bool , bool )
164
+ CPU_ACC_TYPE (c10::complex<Half>, c10::complex<float >)
165
+ CPU_ACC_TYPE (c10::complex<float >, c10::complex<double >)
166
+ CPU_ACC_TYPE (c10::complex<double >, c10::complex<double >)
167
167
168
168
TORCH_API c10::ScalarType toAccumulateType (
169
169
c10::ScalarType type,
0 commit comments