Support heterogeneous tensor types for vectorized elementwise kernel.
This patch includes changes from pytorch#147527 with
an extension to support multiple input tensor types (e.g., a binary float
functor whose inputs are a mix of float and BFloat16 or Half tensors).
carlobertolli committed Mar 7, 2025
1 parent 5eaa466 commit 4194409
Showing 3 changed files with 546 additions and 20 deletions.
357 changes: 357 additions & 0 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -50,6 +50,20 @@
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif

namespace vectorized_templated_config {
constexpr int num_threads() {
return 256;
}

constexpr int thread_work_size() {
return 64;
}

constexpr int block_work_size() {
return thread_work_size() * num_threads();
}
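// With the values above, each block covers
// block_work_size() = 256 * 64 = 16384 elements.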
} // namespace vectorized_templated_config

namespace at {
namespace native {

@@ -147,6 +161,142 @@ static inline void launch_vectorized_kernel(
}
}

#ifdef USE_ROCM
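// Each block of this kernel processes block_work_size() elements. Blocks that
// receive a full block's worth of data use vectorized, statically typed memory
// accesses; the final partial block falls back to a naive unrolled loop with
// explicit load/store casting.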
template <
int vec_size,
typename func_t,
typename array_t,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
typename storer_t,
typename OutputType,
typename... InputTypes>
C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
__global__ void vectorized_templated_elementwise_kernel(
int N,
func_t f,
array_t data,
inp_calc_t inp_calc,
out_calc_t out_calc,
loader_t loader,
storer_t storer) {
using traits = function_traits<func_t>;
int remaining =
N - vectorized_templated_config::block_work_size() * blockIdx.x;
if (remaining <
vectorized_templated_config::block_work_size()) { // if this block handles
// the remainder,
// just do a naive unrolled loop
auto policy = memory::policies::unroll_base<
vectorized_templated_config::thread_work_size(),
vectorized_templated_config::num_threads(),
vectorized_templated_config::block_work_size(),
array_t,
inp_calc_t,
out_calc_t,
loader_t,
storer_t>(data, remaining, inp_calc, out_calc, loader, storer);
templated_elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` worth of data, use
// vectorized memory access
templated_elementwise_kernel_helper<
vectorized_templated_config::thread_work_size()>(
f,
memory::policies::vectorized_templated<
vectorized_templated_config::thread_work_size(),
vectorized_templated_config::num_threads(),
vectorized_templated_config::block_work_size(),
vec_size,
array_t,
OutputType,
InputTypes...>(data));
}
}

// This function assumes a trivial 1d layout and supports template
// specialization to avoid dynamic casting.
// The input vectorization size is based on runtime information, i.e. the
// actual data types of the input and output tensors, and cannot be
// determined from the functor type alone, as it is in regular non-templated
// vectorized kernels. The caller is in charge of selecting the correct input
// vectorization length.
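// Illustrative example (not part of the original patch): for a binary functor
// float f(float, float) applied to (float, BFloat16) inputs, the caller in
// gpu_kernel_impl below instantiates this helper roughly as
//   launch_vectorized_templated_kernel<
//       func_t, at::detail::Array<char*, 3>,
//       decltype(ic), decltype(oc), decltype(l), decltype(s),
//       float, float, BFloat16>(numel, f, data, ic, oc, l, s);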
template <
typename func_t,
typename array_t,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
typename storer_t,
typename OutputType,
typename... InputTypes>
static inline void launch_vectorized_templated_kernel(
int64_t N,
const func_t& f,
array_t data,
inp_calc_t ic,
out_calc_t oc,
loader_t l,
storer_t s) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
using traits = function_traits<func_t>;
int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) /
vectorized_templated_config::block_work_size();
auto stream = at::cuda::getCurrentCUDAStream();
int vec_size = memory::can_vectorize_up_to<func_t>(data);
switch (vec_size) {
case 8:
vectorized_templated_elementwise_kernel<
8,
func_t,
array_t,
inp_calc_t,
out_calc_t,
loader_t,
storer_t,
OutputType,
InputTypes...>
<<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 4:
vectorized_templated_elementwise_kernel<
4,
func_t,
array_t,
inp_calc_t,
out_calc_t,
loader_t,
storer_t,
OutputType,
InputTypes...>
<<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 2:
vectorized_templated_elementwise_kernel<
2,
func_t,
array_t,
inp_calc_t,
out_calc_t,
loader_t,
storer_t,
OutputType,
InputTypes...>
<<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
default:
// vector size 1 is not handled as part of the vectorized templated kernel
TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
}
}
#endif

template <
typename func_t,
typename array_t,
@@ -425,6 +575,104 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
#endif
}

#ifdef USE_ROCM
namespace {
// Static functor type checker for binary functors with
// float as the type of both parameters.
template <
typename TupleLike,
typename FirstParamTy,
typename SecondParamTy,
size_t arity,
size_t arg_num = 0>
struct check_binary_functor_types_for_specialization {
constexpr static inline bool check() {
if constexpr (arity != 2)
return false;
if constexpr (arg_num == 0) {
using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
if constexpr (std::is_same_v<float, SelectedType>)
return check_binary_functor_types_for_specialization<
TupleLike,
FirstParamTy,
SecondParamTy,
arity,
arg_num + 1>::check();
} else if constexpr (arg_num == 1) {
using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
if constexpr (std::is_same_v<float, SelectedType2>)
return check_binary_functor_types_for_specialization<
TupleLike,
FirstParamTy,
SecondParamTy,
arity,
arg_num + 1>::check();
}
return false;
}
};

// Base case: if we got this far, assume correct type matching, except
// when there are no arguments (arity == 0).
template <
typename TupleLike,
typename FirstParamTy,
typename SecondParamTy,
size_t arity>
struct check_binary_functor_types_for_specialization<
TupleLike,
FirstParamTy,
SecondParamTy,
arity,
arity> {
constexpr static inline bool check() {
if constexpr (arity != 0)
return true;
return false;
}
};

template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
struct check_binary_functor_types_for_specialization<
TupleLike,
FirstParamTy,
SecondParamTy,
0,
0> {
constexpr static inline bool check() {
return false;
}
};
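// Illustrative example (not part of the original patch): for a functor
// float f(float, float), ArgsTuple is std::tuple<float, float>, so
//   check_binary_functor_types_for_specialization<
//       std::tuple<float, float>, float, float, /*arity=*/2>::check()
// evaluates to true at compile time, while any non-float parameter type
// yields false.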

// The following is a list of type specializations for the vectorized_templated
// elementwise kernel. Each entry gives the first and second runtime input
// types of a binary functor for which a specialization exists.
constexpr int number_of_binary_specializations = 4;
const std::
array<std::array<c10::ScalarType, 2>, number_of_binary_specializations>
rt_binary_specializations = {
{{c10::CppTypeToScalarType<float>::value,
c10::CppTypeToScalarType<BFloat16>::value},
{c10::CppTypeToScalarType<BFloat16>::value,
c10::CppTypeToScalarType<float>::value},
{c10::CppTypeToScalarType<float>::value,
c10::CppTypeToScalarType<Half>::value},
{c10::CppTypeToScalarType<Half>::value,
c10::CppTypeToScalarType<float>::value}}};

bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
if (iter.ninputs() != 2)
return false;
for (int i = 0; i < number_of_binary_specializations; i++)
if (iter.input_dtype(0) == rt_binary_specializations[i][0] &&
iter.input_dtype(1) == rt_binary_specializations[i][1])
return true;
return false;
}
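// For example, an iterator whose two inputs have dtypes (Float, BFloat16)
// matches the first entry of rt_binary_specializations and returns true,
// while (Float, Float) inputs do not match and stay on the generic
// vectorized path.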
} // namespace
#endif

template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -449,6 +697,115 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {

if (contiguous) {
#ifdef USE_ROCM
// Attempt to call specialized vectorized elementwise kernel
// that enables interleaving.
if (check_binary_rt_types_for_specialization(iter) &&
memory::can_vectorize_up_to<func_t>(data) > 1) {
// Use constexpr checks to reduce the number of (empty) kernels generated
// for the unrolled templated elementwise path and to limit, at compile
// time, which functors are actually applied to the load and store.
using func_tuple = typename traits::ArgsTuple;
if constexpr (
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
check_binary_functor_types_for_specialization<
func_tuple,
float,
float,
traits::arity,
/*arg_num=*/0>::check()) {
// If we got here, we know we are in one of the specialized cases. We
// need to translate the runtime type to a statically known type. This
// effectively hoists to the host the switch over runtime types that
// fetch_and_cast performs in the kernel. Loader, storer, and offset
// calculators are only needed for the remainder loop.
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
auto output_offset_calculator = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithCast<traits::arity>(iter);
auto storer = memory::StoreWithCast<1>(iter);
if (iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
iter.input_dtype(1) == c10::CppTypeToScalarType<BFloat16>::value)
launch_vectorized_templated_kernel<
func_t,
at::detail::Array<char*, ntensors>,
decltype(input_offset_calculator),
decltype(output_offset_calculator),
decltype(loader),
decltype(storer),
float,
float,
BFloat16>(
numel,
f,
data,
input_offset_calculator,
output_offset_calculator,
loader,
storer);
else if (
iter.input_dtype(0) == c10::CppTypeToScalarType<BFloat16>::value &&
iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
launch_vectorized_templated_kernel<
func_t,
at::detail::Array<char*, ntensors>,
decltype(input_offset_calculator),
decltype(output_offset_calculator),
decltype(loader),
decltype(storer),
float,
BFloat16,
float>(
numel,
f,
data,
input_offset_calculator,
output_offset_calculator,
loader,
storer);
else if (
iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
iter.input_dtype(1) == c10::CppTypeToScalarType<Half>::value)
launch_vectorized_templated_kernel<
func_t,
at::detail::Array<char*, ntensors>,
decltype(input_offset_calculator),
decltype(output_offset_calculator),
decltype(loader),
decltype(storer),
float,
float,
Half>(
numel,
f,
data,
input_offset_calculator,
output_offset_calculator,
loader,
storer);
else if (
iter.input_dtype(0) == c10::CppTypeToScalarType<Half>::value &&
iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
launch_vectorized_templated_kernel<
func_t,
at::detail::Array<char*, ntensors>,
decltype(input_offset_calculator),
decltype(output_offset_calculator),
decltype(loader),
decltype(storer),
float,
Half,
float>(
numel,
f,
data,
input_offset_calculator,
output_offset_calculator,
loader,
storer);
else
TORCH_CHECK(false, "unreachable");
return;
}
}
at::detail::Array<ScalarType, ntensors> dtypes;
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;