
Commit 3a0c680

Jiterates exp2, erfc, erfinv and entr and refactors code_template.h to ATen (pytorch#71295)

Mike Ruberry authored and facebook-github-bot committed

Summary: Per title.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: pytorch#71295
Reviewed By: ngimel
Differential Revision: D33575885
Pulled By: mruberry
fbshipit-source-id: bc841b46fc0b5458a26a4d4465b18a7a54cd5a5b

1 parent d068849 · commit 3a0c680

21 files changed (+183, -355 lines)

torch/csrc/jit/frontend/code_template.h → aten/src/ATen/code_template.h (+4, -4)

@@ -7,8 +7,7 @@
 #include <unordered_map>
 #include <vector>
 
-namespace torch {
-namespace jit {
+namespace at { namespace jit {
 
 // A template environment is a mapping from template variable names, e.g.,
 // identifier (corresponding to $identifier) to their expansions.
@@ -85,6 +84,7 @@ struct TemplateEnv {
     ss << "key not found: " << k;
     throw std::logic_error(ss.str());
   }
+
   std::unordered_map<std::string, std::string> strings_;
   std::unordered_map<std::string, string_list> lists_;
   TemplateEnv* parent;
@@ -238,9 +238,9 @@ struct CodeTemplate {
   }
   std::string template_text;
 };
+
 static inline std::string format(const std::string& fmt, TemplateEnv& env) {
   return CodeTemplate(fmt).format(env);
 }
 
-} // namespace jit
-} // namespace torch
+}} // at::jit
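For reference, a minimal sketch of how the relocated API reads after this change, assuming a translation unit that can include <ATen/code_template.h>. The $identifier substitution syntax is taken from the header's own comment above, and the free format() helper is the one defined in the last hunk; the snippet itself is illustrative, not part of this commit.

// Minimal sketch (not part of this commit) of the relocated template API.
#include <ATen/code_template.h>
#include <iostream>

int main() {
  at::jit::TemplateEnv env;
  env.s("name", "exp2_kernel");  // bind $name to a string value
  // at::jit::format() is the free helper defined in the header above.
  std::cout << at::jit::format("void $name();", env) << "\n";  // prints: void exp2_kernel();
}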

aten/src/ATen/native/cuda/Math.cuh (+40)

@@ -503,6 +503,46 @@ const auto lgamma_string = jiterator_stringify(
   }
 ); // lgamma_string
 
+const auto exp2_string = jiterator_stringify(
+  template <typename T>
+  T exp2_kernel(T a) {
+    return exp2(a);
+  }
+); // exp2_string
+
+const auto erfc_string = jiterator_stringify(
+  template <typename T>
+  T erfc_kernel(T a) {
+    return erfc(a);
+  }
+); // erfc_string
+
+const auto erfinv_string = jiterator_stringify(
+  template <typename T>
+  T erfinv_kernel(T a) {
+    return erfinv(a);
+  }
+); // erfinv_string
+
+const auto entr_string = jiterator_stringify(
+  template <typename T>
+  T entr(T a) {
+    if (a != a) {
+      return a;
+    }
+
+    if (a > 0) {
+      return -a * log(a);
+    }
+
+    if (a == 0) {
+      return 0;
+    }
+
+    return NEG_INFINITY;
+  }
+); // entr_string
+
 const auto i0_string = jiterator_stringify(
   template<typename T>
   T chbevl(T x, const T array[], const int len) {
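The entr string above encodes the entropy function piecewise: NaN propagates (a != a is true only for NaN), -x * ln(x) on the positive axis, 0 at zero, and negative infinity for negative inputs. A self-contained host-side C++ sketch of the same logic, for clarity only (it is not the kernel; std::numeric_limits stands in for the kernel's NEG_INFINITY):

#include <cmath>
#include <cstdio>
#include <limits>

// Host-side restatement of the entr kernel's piecewise definition.
template <typename T>
T entr_ref(T a) {
  if (a != a) {                   // NaN is the only value for which a != a
    return a;                     // propagate NaN
  }
  if (a > 0) {
    return -a * std::log(a);      // -x * ln(x) on the positive axis
  }
  if (a == 0) {
    return T(0);                  // lim_{x -> 0+} of -x ln x is 0
  }
  return -std::numeric_limits<T>::infinity();  // negative inputs
}

int main() {
  std::printf("%f %f %f\n", entr_ref(0.5), entr_ref(0.0), entr_ref(-1.0));
  // prints: 0.346574 0.000000 -inf
}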

aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu (+73, -36)

@@ -19,20 +19,28 @@
 namespace at {
 namespace native {
 
+const char exp2_name[] = "exp2_kernel";
 void exp2_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "exp2_cuda",
-      [&]() {
-        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
-          return ::exp2(a);
-        });
+  #ifdef USE_JITERATOR
+    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "exp2_cuda", [&]() {
+      jitted_gpu_kernel</*name=*/exp2_name,
+                        /*return_dtype=*/ scalar_t,
+                        /*common_dtype=*/ scalar_t,
+                        /*arity=*/ 1>(iter, exp2_string);
       });
+  #else
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half, ScalarType::BFloat16,
+        iter.common_dtype(), "exp2_cuda",
+        [&]() {
+          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
+            return ::exp2(a);
+          });
+        });
+  #endif
 }
 
-namespace {
 const char i0_name[] = "i0";
-}
 void i0_kernel_cuda(TensorIteratorBase& iter) {
 #ifdef USE_JITERATOR
   AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {
@@ -74,9 +82,8 @@ void i0e_kernel_cuda(TensorIteratorBase& iter) {
 }
 
 // See note [Jiterator]
-namespace {
+
 const char i1_name[] = "i1";
-}
 void i1_kernel_cuda(TensorIteratorBase& iter) {
 #ifdef USE_JITERATOR
   AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() {
@@ -189,21 +196,41 @@ void erf_kernel_cuda(TensorIteratorBase& iter) {
   });
 }
 
+const char erfc_name[] = "erfc_kernel";
 void erfc_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
-      iter.common_dtype(), "erfc_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return ::erfc(a);
-    });
+  #ifdef USE_JITERATOR
+    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfc_cuda", [&]() {
+      jitted_gpu_kernel</*name=*/erfc_name,
+                        /*return_dtype=*/ scalar_t,
+                        /*common_dtype=*/ scalar_t,
+                        /*arity=*/ 1>(iter, erfc_string);
     });
+  #else
+    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
+        iter.common_dtype(), "erfc_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return ::erfc(a);
+      });
+    });
+  #endif
 }
 
+const char erfinv_name[] = "erfinv_kernel";
 void erfinv_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return ::erfinv(a);
+  #ifdef USE_JITERATOR
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
+      jitted_gpu_kernel</*name=*/erfinv_name,
+                        /*return_dtype=*/ scalar_t,
+                        /*common_dtype=*/ scalar_t,
+                        /*arity=*/ 1>(iter, erfinv_string);
+    });
+  #else
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return ::erfinv(a);
+      });
     });
-  });
+  #endif
 }
 
 const char erfcx_name[] = "erfcx";
@@ -237,24 +264,34 @@ void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length,
   });
 }
 
+const char entr_name[] = "entr";
 void entr_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      ScalarType::Half,
-      ScalarType::BFloat16,
-      iter.common_dtype(),
-      "entr_cuda",
-      [&]() {
-        gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
-          if (at::_isnan(x)) {
-            return x;
-          } else if (x > 0) {
-            return -x * std::log(x);
-          } else if (x == 0) {
-            return 0;
-          }
-          return static_cast<scalar_t>(-INFINITY);
-        });
+  #ifdef USE_JITERATOR
+    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "entr_cuda", [&]() {
+      jitted_gpu_kernel</*name=*/entr_name,
+                        /*return_dtype=*/ scalar_t,
+                        /*common_dtype=*/ scalar_t,
+                        /*arity=*/ 1>(iter, entr_string);
       });
+  #else
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        ScalarType::Half,
+        ScalarType::BFloat16,
+        iter.common_dtype(),
+        "entr_cuda",
+        [&]() {
+          gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
+            if (at::_isnan(x)) {
+              return x;
+            } else if (x > 0) {
+              return -x * std::log(x);
+            } else if (x == 0) {
+              return 0;
+            }
+            return static_cast<scalar_t>(-INFINITY);
+          });
+        });
+  #endif
 }
 
 REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
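Each port above follows the same recipe: a const char name[] whose value matches the function name inside the corresponding jiterator_stringify string, a jitted_gpu_kernel instantiation guarded by USE_JITERATOR, and the pre-existing gpu_kernel lambda kept as the fallback. A hedged sketch of that recipe for a hypothetical unary op (mycos is not a real ATen op, and the snippet presumes the same includes and namespaces as the file above; it only illustrates the shape of the pattern):

// Hypothetical example following the recipe above; the name constant must
// match the function name inside the stringified kernel so the JIT-compiled
// source resolves the entry point.
const auto mycos_string = jiterator_stringify(
  template <typename T>
  T mycos_kernel(T a) {
    return cos(a);
  }
); // mycos_string

const char mycos_name[] = "mycos_kernel";
void mycos_kernel_cuda(TensorIteratorBase& iter) {
  #ifdef USE_JITERATOR
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "mycos_cuda", [&]() {
      jitted_gpu_kernel</*name=*/mycos_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, mycos_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "mycos_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ::cos(a);
      });
    });
  #endif
}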

aten/src/ATen/native/cuda/jit_utils.cu (+8, -6)

@@ -1,11 +1,13 @@
 #include <sstream>
 
+#include <c10/core/ScalarType.h>
+#include <c10/util/irange.h>
+#include <c10/cuda/CUDACachingAllocator.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
-#include <c10/cuda/CUDACachingAllocator.h>
+#include <ATen/code_template.h>
 #include <ATen/native/cuda/jit_utils.h>
-#include <c10/core/ScalarType.h>
-#include <c10/util/irange.h>
+
 
 namespace at { namespace cuda { namespace jit {
 
@@ -582,7 +584,7 @@ std::string generate_code(
     BinaryFuncVariant scalar_pos,
     bool vectorized,
     int vec_size) {
-  TemplateEnv env;
+  at::jit::TemplateEnv env;
   env.s("index_type", "unsigned int");
   const int nInputs = nTensors - 1;
   env.s("nInputs", std::to_string(nInputs));
@@ -661,7 +663,7 @@ std::string generate_code(
   store_outputs << "s.store<" << result_type
                 << ">(out[j], data[0], output_offsets[0]);\n";
   env.s("store_outputs", store_outputs.str());
-  static auto cuda_template = CodeTemplate(jit_common_types + jit_code_template);
+  static auto cuda_template = at::jit::CodeTemplate(jit_common_types + jit_code_template);
   return cuda_template.format(env);
 }
@@ -694,7 +696,7 @@ std::string generate_code(
   }
   env.s("load_unrolled_inputs", load_unrolled_inputs.str());
 
-  static auto cuda_template = CodeTemplate(jit_common_types + jit_vectorized_code_template);
+  static auto cuda_template = at::jit::CodeTemplate(jit_common_types + jit_vectorized_code_template);
   return cuda_template.format(env);
 }

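One detail worth noting in the last two hunks: cuda_template is a function-local static, so the template string is parsed once and reused across generate_code calls, while each call fills in a fresh TemplateEnv. A minimal sketch of that caching pattern (make_source is a hypothetical helper, not ATen code):

#include <ATen/code_template.h>
#include <string>

// Hypothetical helper mirroring the caching pattern in generate_code: the
// CodeTemplate is parsed once (function-local static); only the TemplateEnv
// substitutions vary per call.
std::string make_source(const std::string& kernel_name) {
  static auto tmpl = at::jit::CodeTemplate("extern \"C\" __global__ void $name();");
  at::jit::TemplateEnv env;
  env.s("name", kernel_name);
  return tmpl.format(env);
}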