
Support heterogeneous tensor types for vectorized elementwise kernel.
This patch includes changes from pytorch#147527 with
an extension to support multiple input tensor types.
carlobertolli committed Mar 7, 2025
1 parent 5eaa466 commit 34f7271
Showing 3 changed files with 345 additions and 18 deletions.
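For context, here is a minimal usage sketch (an editorial illustration, not part of the diff; it assumes a ROCm build of PyTorch with the C++ API). The patch targets contiguous elementwise binary ops whose two inputs have different dtypes:

#include <torch/torch.h>

int main() {
  // Heterogeneous input dtypes: float and bfloat16, both contiguous on GPU.
  auto a = torch::randn(
      {1 << 20}, torch::dtype(torch::kFloat).device(torch::kCUDA));
  auto b = torch::randn(
      {1 << 20}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
  // With this patch, ROCm can route such float/bfloat16 pairs to the
  // vectorized templated kernel instead of the dynamic-casting fallback.
  auto c = a + b; // output is promoted to float
  return 0;
}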
207 changes: 207 additions & 0 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -50,6 +50,16 @@
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif

namespace vectorized_templated_config {
constexpr int num_threads() {
  return 256;
}

constexpr int thread_work_size() {
  return 64;
}

constexpr int block_work_size() {
  return thread_work_size() * num_threads();
}
} // namespace vectorized_templated_config
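// With these constants, each block covers block_work_size() =
// 64 * 256 = 16384 elements per launch.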

namespace at {
namespace native {

@@ -147,6 +157,92 @@ static inline void launch_vectorized_kernel(
}
}

#ifdef USE_ROCM
template <
    int vec_size,
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t,
    typename OutputType,
    typename... InputTypes>
C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
__global__ void vectorized_templated_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t inp_calc,
    out_calc_t out_calc,
    loader_t loader,
    storer_t storer) {
  int remaining =
      N - vectorized_templated_config::block_work_size() * blockIdx.x;
  if (remaining < vectorized_templated_config::block_work_size()) {
    // This block handles the remainder, so just do a naive unrolled loop.
    auto policy = memory::policies::unroll_base<
        vectorized_templated_config::thread_work_size(),
        vectorized_templated_config::num_threads(),
        vectorized_templated_config::block_work_size(),
        array_t,
        inp_calc_t,
        out_calc_t,
        loader_t,
        storer_t>(data, remaining, inp_calc, out_calc, loader, storer);
    templated_elementwise_kernel_helper(f, policy);
  } else {
    // This block has a full `block_work_size` of data to handle, so use
    // vectorized memory access.
    templated_elementwise_kernel_helper<
        vectorized_templated_config::thread_work_size()>(
        f,
        memory::policies::vectorized_templated<
            vectorized_templated_config::thread_work_size(),
            vectorized_templated_config::num_threads(),
            vectorized_templated_config::block_work_size(),
            vec_size,
            array_t,
            OutputType,
            InputTypes...>(data));
  }
}
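// Example: with block_work_size() = 16384 and N = 20000, block 0 computes
// remaining = 20000 and takes the vectorized path, while block 1 computes
// remaining = 20000 - 16384 = 3616 and falls back to the unrolled
// remainder loop.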


// This function assumes trivial 1D indexing and supports template
// specialization to avoid dynamic casting.
// The input vectorization size is based on runtime information, i.e. the
// actual data types of the input and output tensors, and cannot be
// determined from the functor type alone, as in regular non-templated
// vectorized kernels. The caller is in charge of selecting the correct input
// vectorization length.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t,
    typename OutputType,
    typename... InputTypes>
static inline void launch_vectorized_templated_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid =
      (N + vectorized_templated_config::block_work_size() - 1) /
      vectorized_templated_config::block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  int vec_size = memory::can_vectorize_up_to<func_t>(data);
  switch (vec_size) {
    case 8:
      vectorized_templated_elementwise_kernel<
          8, func_t, array_t, inp_calc_t, out_calc_t, loader_t, storer_t,
          OutputType, InputTypes...>
          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
              N, f, data, ic, oc, l, s);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 4:
      vectorized_templated_elementwise_kernel<
          4, func_t, array_t, inp_calc_t, out_calc_t, loader_t, storer_t,
          OutputType, InputTypes...>
          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
              N, f, data, ic, oc, l, s);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      vectorized_templated_elementwise_kernel<
          2, func_t, array_t, inp_calc_t, out_calc_t, loader_t, storer_t,
          OutputType, InputTypes...>
          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
              N, f, data, ic, oc, l, s);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    default:
      // Vector size 1 is not handled by the vectorized templated kernel.
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}
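// Example grid sizing: for N = 1 << 20 (1,048,576 elements) and
// block_work_size() = 16384, grid = 1048576 / 16384 = 64 blocks of 256
// threads each; can_vectorize_up_to() then selects vec_size at runtime from
// the data pointers.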
#endif

template <
typename func_t,
typename array_t,
@@ -425,6 +521,68 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
#endif
}

#ifdef USE_ROCM
namespace {
// Static functor type checker for binary functors with
// float as the type of both parameters.
template <typename TupleLike, typename FirstParamTy, typename SecondParamTy,
          size_t arity, size_t arg_num = 0>
struct check_binary_functor_types_for_specialization {
  constexpr static inline bool check() {
    if constexpr (arity != 2)
      return false;
    if constexpr (arg_num == 0) {
      using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
      if constexpr (std::is_same_v<float, SelectedType>)
        return check_binary_functor_types_for_specialization<
            TupleLike, FirstParamTy, SecondParamTy, arity, arg_num + 1>::check();
    } else if constexpr (arg_num == 1) {
      using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
      if constexpr (std::is_same_v<float, SelectedType2>)
        return check_binary_functor_types_for_specialization<
            TupleLike, FirstParamTy, SecondParamTy, arity, arg_num + 1>::check();
    }
    return false;
  }
};

// Base case: if we got this far, the argument types matched, except when
// there are no arguments at all (arity == 0).
template <typename TupleLike, typename FirstParamTy, typename SecondParamTy,
          size_t arity>
struct check_binary_functor_types_for_specialization<
    TupleLike, FirstParamTy, SecondParamTy, arity, arity> {
  constexpr static inline bool check() {
    return arity != 0;
  }
};

template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
struct check_binary_functor_types_for_specialization<
    TupleLike, FirstParamTy, SecondParamTy, 0, 0> {
  constexpr static inline bool check() {
    return false;
  }
};

// The following lists the type specializations for the vectorized_templated
// elementwise kernel. Each entry gives the first and second runtime types of
// the arguments of a binary functor.
constexpr int number_of_binary_specializations = 4;
const std::array<std::array<c10::ScalarType, 2>, number_of_binary_specializations>
    rt_binary_specializations = {{
        {c10::CppTypeToScalarType<float>::value, c10::CppTypeToScalarType<BFloat16>::value},
        {c10::CppTypeToScalarType<BFloat16>::value, c10::CppTypeToScalarType<float>::value},
        {c10::CppTypeToScalarType<float>::value, c10::CppTypeToScalarType<Half>::value},
        {c10::CppTypeToScalarType<Half>::value, c10::CppTypeToScalarType<float>::value}}};

bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
  if (iter.ninputs() != 2)
    return false;
  for (size_t i = 0; i < rt_binary_specializations.size(); i++)
    if (iter.input_dtype(0) == rt_binary_specializations[i][0] &&
        iter.input_dtype(1) == rt_binary_specializations[i][1])
      return true;
  return false;
}
} // anonymous namespace
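// Illustrative compile-time checks (a sketch, not part of the original
// patch): the functor type checker accepts a binary functor only when both
// argument types are float.
static_assert(
    check_binary_functor_types_for_specialization<
        std::tuple<float, float>, float, float, 2>::check(),
    "binary float/float functors should be accepted");
static_assert(
    !check_binary_functor_types_for_specialization<
        std::tuple<float, double>, float, float, 2>::check(),
    "functors with a non-float argument should be rejected");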
#endif

template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -449,6 +607,55 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {

if (contiguous) {
#ifdef USE_ROCM
    // Attempt to call a specialized vectorized elementwise kernel that
    // enables interleaving.
    if (check_binary_rt_types_for_specialization(iter)) {
      // The constexpr check reduces the number of (empty) kernels generated
      // for the unrolled templated elementwise path and limits which functors
      // are actually applied to the load and store at compile time.
      using func_tuple = typename traits::ArgsTuple;
      if constexpr (
          std::is_same_v<float, arg0_t> && traits::arity == 2 &&
          check_binary_functor_types_for_specialization<
              func_tuple, float, float, traits::arity>::check()) {
        // If we got here, we know we are in one of the specialized cases. We
        // need to translate the runtime type to a statically known type; this
        // effectively hoists to the host the switch over the runtime type
        // that fetch_and_cast performs in the kernel.
        // Loader, storer, and offset calculators are only needed for the
        // remainder loop.
        auto input_offset_calculator =
            TrivialOffsetCalculator<traits::arity>();
        auto output_offset_calculator = TrivialOffsetCalculator<1>();
        auto loader = memory::LoadWithCast<traits::arity>(iter);
        auto storer = memory::StoreWithCast<1>(iter);
        if (iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
            iter.input_dtype(1) == c10::CppTypeToScalarType<BFloat16>::value)
          launch_vectorized_templated_kernel<
              func_t, at::detail::Array<char*, ntensors>,
              decltype(input_offset_calculator),
              decltype(output_offset_calculator), decltype(loader),
              decltype(storer), float, float, BFloat16>(
              numel, f, data, input_offset_calculator,
              output_offset_calculator, loader, storer);
        else if (
            iter.input_dtype(0) == c10::CppTypeToScalarType<BFloat16>::value &&
            iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
          launch_vectorized_templated_kernel<
              func_t, at::detail::Array<char*, ntensors>,
              decltype(input_offset_calculator),
              decltype(output_offset_calculator), decltype(loader),
              decltype(storer), float, BFloat16, float>(
              numel, f, data, input_offset_calculator,
              output_offset_calculator, loader, storer);
        else if (
            iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
            iter.input_dtype(1) == c10::CppTypeToScalarType<Half>::value)
          launch_vectorized_templated_kernel<
              func_t, at::detail::Array<char*, ntensors>,
              decltype(input_offset_calculator),
              decltype(output_offset_calculator), decltype(loader),
              decltype(storer), float, float, Half>(
              numel, f, data, input_offset_calculator,
              output_offset_calculator, loader, storer);
        else if (
            iter.input_dtype(0) == c10::CppTypeToScalarType<Half>::value &&
            iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
          launch_vectorized_templated_kernel<
              func_t, at::detail::Array<char*, ntensors>,
              decltype(input_offset_calculator),
              decltype(output_offset_calculator), decltype(loader),
              decltype(storer), float, Half, float>(
              numel, f, data, input_offset_calculator,
              output_offset_calculator, loader, storer);
        else
          TORCH_CHECK(false, "unreachable");
        return;
      }
    }
at::detail::Array<ScalarType, ntensors> dtypes;
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;
28 changes: 28 additions & 0 deletions aten/src/ATen/native/cuda/Loops.cuh
@@ -66,6 +66,34 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
policy.store(results, idx);
}

#ifdef USE_ROCM
template <
    int thread_work_size = thread_work_size(),
    typename func_t,
    typename policy_t>
__device__ inline void templated_elementwise_kernel_helper(
    func_t f,
    policy_t policy) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  using args_t = typename traits::ArgsTuple;

  int idx = blockIdx.x;

  return_t results[thread_work_size];
  args_t args[thread_work_size];

  // load
  policy.load(args, idx);

  // compute
#pragma unroll
  for (int i = 0; i < thread_work_size; i++) {
    if (policy.check_inbounds(i)) {
      results[i] = c10::guts::apply(f, args[i]);
    }
  }

  // store
  policy.store(results, idx);
}
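// Unlike elementwise_kernel_helper above, the per-thread work count here is a
// template parameter (defaulting to the global thread_work_size()), so a
// caller such as CUDALoops.cuh can instantiate, e.g.,
//   templated_elementwise_kernel_helper<64>(f, policy);
// to match vectorized_templated_config::thread_work_size().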
#endif

}} // namespace at::native

#include <ATen/native/cuda/CUDALoops.cuh>
