19 | 19 | namespace at {
20 | 20 | namespace native {
21 | 21 |
| 22 | +const char exp2_name[] = "exp2_kernel";
22 | 23 | void exp2_kernel_cuda(TensorIteratorBase& iter) {
23 | | - AT_DISPATCH_FLOATING_TYPES_AND2(
24 | | - ScalarType::Half, ScalarType::BFloat16,
25 | | - iter.common_dtype(), "exp2_cuda",
26 | | - [&]() {
27 | | - gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
28 | | - return ::exp2(a);
29 | | - });
| 24 | + #ifdef USE_JITERATOR
| 25 | + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "exp2_cuda", [&]() {
| 26 | + jitted_gpu_kernel</*name=*/exp2_name,
| 27 | + /*return_dtype=*/ scalar_t,
| 28 | + /*common_dtype=*/ scalar_t,
| 29 | + /*arity=*/ 1>(iter, exp2_string);
30 | 30 | });
| 31 | + #else
| 32 | + AT_DISPATCH_FLOATING_TYPES_AND2(
| 33 | + ScalarType::Half, ScalarType::BFloat16,
| 34 | + iter.common_dtype(), "exp2_cuda",
| 35 | + [&]() {
| 36 | + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
| 37 | + return ::exp2(a);
| 38 | + });
| 39 | + });
| 40 | + #endif
31 | 41 | }
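The jitted branch compiles the op at runtime with NVRTC, so the kernel body travels as a source string (`exp2_string`, defined alongside the other Jiterator strings elsewhere in the tree) and the function it defines must carry the name passed as the `name` template argument, here `exp2_name` ("exp2_kernel"). A minimal sketch of what such a string could look like, assuming ATen's `jiterator_stringify` helper; the erfc, erfinv, and entr conversions below follow the same pattern:

```cpp
// Hypothetical sketch of exp2_string; the real definition is not part of
// this hunk. jiterator_stringify turns the code into a compile-time string
// that the Jiterator later hands to NVRTC.
const auto exp2_string = jiterator_stringify(
  template <typename T>
  T exp2_kernel(T a) {  // name must match exp2_name
    return exp2(a);
  }
);
```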
32 | 42 |
33 | | -namespace {
34 | 43 | const char i0_name[] = "i0";
35 | | -}
36 | 44 | void i0_kernel_cuda(TensorIteratorBase& iter) {
37 | 45 | #ifdef USE_JITERATOR
38 | 46 | AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {

@@ -74,9 +82,8 @@ void i0e_kernel_cuda(TensorIteratorBase& iter) {
74 | 82 | }
75 | 83 |
76 | 84 | // See note [Jiterator]
77 | | -namespace {
| 85 | +
78 | 86 | const char i1_name[] = "i1";
79 | | -}
80 | 87 | void i1_kernel_cuda(TensorIteratorBase& iter) {
81 | 88 | #ifdef USE_JITERATOR
82 | 89 | AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() {

@@ -189,21 +196,41 @@ void erf_kernel_cuda(TensorIteratorBase& iter) {

189 | 196 | });
190 | 197 | }
191 | 198 |
| 199 | +const char erfc_name[] = "erfc_kernel";
192 | 200 | void erfc_kernel_cuda(TensorIteratorBase& iter) {
193 | | - AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
194 | | - iter.common_dtype(), "erfc_cuda", [&]() {
195 | | - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
196 | | - return ::erfc(a);
197 | | - });
| 201 | + #ifdef USE_JITERATOR
| 202 | + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfc_cuda", [&]() {
| 203 | + jitted_gpu_kernel</*name=*/erfc_name,
| 204 | + /*return_dtype=*/ scalar_t,
| 205 | + /*common_dtype=*/ scalar_t,
| 206 | + /*arity=*/ 1>(iter, erfc_string);
198 | 207 | });
| 208 | + #else
| 209 | + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
| 210 | + iter.common_dtype(), "erfc_cuda", [&]() {
| 211 | + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
| 212 | + return ::erfc(a);
| 213 | + });
| 214 | + });
| 215 | + #endif
199 | 216 | }
200 | 217 |
| 218 | +const char erfinv_name[] = "erfinv_kernel";
201 | 219 | void erfinv_kernel_cuda(TensorIteratorBase& iter) {
202 | | - AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
203 | | - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
204 | | - return ::erfinv(a);
| 220 | + #ifdef USE_JITERATOR
| 221 | + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
| 222 | + jitted_gpu_kernel</*name=*/erfinv_name,
| 223 | + /*return_dtype=*/ scalar_t,
| 224 | + /*common_dtype=*/ scalar_t,
| 225 | + /*arity=*/ 1>(iter, erfinv_string);
| 226 | + });
| 227 | + #else
| 228 | + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
| 229 | + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
| 230 | + return ::erfinv(a);
| 231 | + });
205 | 232 | });
206 | | - });
| 233 | + #endif
207 | 234 | }
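Note that erfinv keeps the narrower `AT_DISPATCH_FLOATING_TYPES_AND_HALF` (no BFloat16) in both branches, so the jitted path covers exactly the dtypes the eager path did. The referenced `erfinv_string` is likewise defined elsewhere; a hypothetical minimal form would just wrap the device-side function, though the in-tree string may implement the inverse error function directly:

```cpp
// Hypothetical sketch of erfinv_string, assuming the CUDA math library's
// erfinv is visible to the NVRTC-compiled code.
const auto erfinv_string = jiterator_stringify(
  template <typename T>
  T erfinv_kernel(T a) {  // name must match erfinv_name
    return erfinv(a);
  }
);
```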
208 | 235 |
209 | 236 | const char erfcx_name[] = "erfcx";

@@ -237,24 +264,34 @@ void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length,

237 | 264 | });
238 | 265 | }
239 | 266 |
| 267 | +const char entr_name[] = "entr";
240 | 268 | void entr_kernel_cuda(TensorIteratorBase& iter) {
241 | | - AT_DISPATCH_FLOATING_TYPES_AND2(
242 | | - ScalarType::Half,
243 | | - ScalarType::BFloat16,
244 | | - iter.common_dtype(),
245 | | - "entr_cuda",
246 | | - [&]() {
247 | | - gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
248 | | - if (at::_isnan(x)) {
249 | | - return x;
250 | | - } else if (x > 0) {
251 | | - return -x * std::log(x);
252 | | - } else if (x == 0) {
253 | | - return 0;
254 | | - }
255 | | - return static_cast<scalar_t>(-INFINITY);
256 | | - });
| 269 | + #ifdef USE_JITERATOR
| 270 | + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "entr_cuda", [&]() {
| 271 | + jitted_gpu_kernel</*name=*/entr_name,
| 272 | + /*return_dtype=*/ scalar_t,
| 273 | + /*common_dtype=*/ scalar_t,
| 274 | + /*arity=*/ 1>(iter, entr_string);
257 | 275 | });
| 276 | + #else
| 277 | + AT_DISPATCH_FLOATING_TYPES_AND2(
| 278 | + ScalarType::Half,
| 279 | + ScalarType::BFloat16,
| 280 | + iter.common_dtype(),
| 281 | + "entr_cuda",
| 282 | + [&]() {
| 283 | + gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
| 284 | + if (at::_isnan(x)) {
| 285 | + return x;
| 286 | + } else if (x > 0) {
| 287 | + return -x * std::log(x);
| 288 | + } else if (x == 0) {
| 289 | + return 0;
| 290 | + }
| 291 | + return static_cast<scalar_t>(-INFINITY);
| 292 | + });
| 293 | + });
| 294 | + #endif
258 | 295 | }
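The eager lambda being replaced pins down what `entr_string` has to compute: NaN propagates, positive inputs map to `-x * log(x)`, zero maps to zero, and negative inputs to `-inf`. A sketch under those assumptions; `NEG_INFINITY` stands in for whatever infinity constant the Jiterator preamble provides, and `x != x` replaces `at::_isnan`, which the JIT-compiled string cannot reference:

```cpp
// Hypothetical sketch of entr_string, mirroring the eager GPU_LAMBDA above.
const auto entr_string = jiterator_stringify(
  template <typename T>
  T entr(T x) {            // name must match entr_name
    if (x != x) {          // NaN propagates
      return x;
    } else if (x > T{0}) {
      return -x * log(x);
    } else if (x == T{0}) {
      return T{0};
    }
    return NEG_INFINITY;   // x < 0
  }
);
```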
259 | 296 |
260 | 297 | REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
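Both branches are reached only through the dispatcher, so the easiest smoke test is an ordinary ATen call on a CUDA tensor; with `USE_JITERATOR` defined, the first call per dtype pays the NVRTC compilation cost and later calls hit the kernel cache. A standalone sketch (not part of this PR):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Routes through exp2_kernel_cuda via exp2_stub; the first float call
  // JIT-compiles exp2_kernel when USE_JITERATOR is on.
  at::Tensor x = at::rand({1024}, at::kCUDA);
  at::Tensor y = at::exp2(x);
  // exp2(x) == pow(2, x) up to rounding error.
  std::cout << std::boolalpha << at::allclose(y, at::pow(2.0, x)) << "\n";
  return 0;
}
```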