From 4194409f98216ea73cbbbcb2e0018d8182adc43b Mon Sep 17 00:00:00 2001
From: Carlo Bertolli
Date: Mon, 3 Mar 2025 10:53:16 -0600
Subject: [PATCH] Support heterogeneous tensor types for vectorized
 elementwise kernel.

This patch includes changes from https://github.com/pytorch/pytorch/pull/147527
with an extension to support multiple input tensor types.

---
 aten/src/ATen/native/cuda/CUDALoops.cuh    | 357 +++++++++++++++++++++
 aten/src/ATen/native/cuda/Loops.cuh        |  32 ++
 aten/src/ATen/native/cuda/MemoryAccess.cuh | 177 ++++++++--
 3 files changed, 546 insertions(+), 20 deletions(-)

diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index bf98cf46277c71..878a8826855656 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -50,6 +50,20 @@
 #define ASSERT_HOST_DEVICE_LAMBDA(type)
 #endif
 
+namespace vectorized_templated_config {
+constexpr int num_threads() {
+  return 256;
+}
+
+constexpr int thread_work_size() {
+  return 64;
+}
+
+constexpr int block_work_size() {
+  return thread_work_size() * num_threads();
+}
+} // namespace vectorized_templated_config
+
 namespace at {
 namespace native {
 
@@ -147,6 +161,142 @@ static inline void launch_vectorized_kernel(
   }
 }
 
+#ifdef USE_ROCM
+template <
+    int vec_size,
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    typename OutputType,
+    typename... InputTypes>
+C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
+__global__ void vectorized_templated_elementwise_kernel(
+    int N,
+    func_t f,
+    array_t data,
+    inp_calc_t inp_calc,
+    out_calc_t out_calc,
+    loader_t loader,
+    storer_t storer) {
+  using traits = function_traits<func_t>;
+  int remaining =
+      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  if (remaining <
+      vectorized_templated_config::block_work_size()) {
+    // This block handles the remainder: just do a naive unrolled loop.
+    auto policy = memory::policies::unroll_base<
+        vectorized_templated_config::thread_work_size(),
+        vectorized_templated_config::num_threads(),
+        vectorized_templated_config::block_work_size(),
+        array_t,
+        inp_calc_t,
+        out_calc_t,
+        loader_t,
+        storer_t>(data, remaining, inp_calc, out_calc, loader, storer);
+    templated_elementwise_kernel_helper<
+        vectorized_templated_config::thread_work_size()>(f, policy);
+  } else {
+    // This block has a full `block_work_size` of data to handle: use
+    // vectorized memory access.
+    templated_elementwise_kernel_helper<
+        vectorized_templated_config::thread_work_size()>(
+        f,
+        memory::policies::vectorized_templated<
+            vectorized_templated_config::thread_work_size(),
+            vectorized_templated_config::num_threads(),
+            vectorized_templated_config::block_work_size(),
+            vec_size,
+            array_t,
+            OutputType,
+            InputTypes...>(data));
+  }
+}
+
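+// Example: with the configuration above, each block covers
+// block_work_size() = 256 threads * 64 elements per thread = 16384 elements,
+// so a tensor with N = 1'000'000 elements is launched on
+// ceil(1'000'000 / 16384) = 62 blocks; only the last, partially filled block
+// takes the unroll_base remainder path of the kernel above.
+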
+// This function assumes a trivial 1d layout and supports template
+// specialization to avoid dynamic casting.
+// Input vectorization size is based on runtime information, i.e.
+// the actual data types of the input and output tensors, and cannot
+// be determined from the functor type alone, as in regular non-templated
+// vectorized kernels. The caller is in charge of selecting the correct input
+// vectorization length.
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    typename OutputType,
+    typename... InputTypes>
+static inline void launch_vectorized_templated_kernel(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  using traits = function_traits<func_t>;
+  int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) /
+      vectorized_templated_config::block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int vec_size = memory::can_vectorize_up_to<OutputType, InputTypes...>(data);
+  switch (vec_size) {
+    case 8:
+      vectorized_templated_elementwise_kernel<
+          8,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 4:
+      vectorized_templated_elementwise_kernel<
+          4,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 2:
+      vectorized_templated_elementwise_kernel<
+          2,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    default:
+      // Vector size 1 is not handled as part of the vectorized_templated
+      // kernel.
+      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
+  }
+}
+#endif
+
 template <
     typename func_t,
     typename array_t,
@@ -425,6 +575,104 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 #endif
 }
 
+#ifdef USE_ROCM
+namespace {
+// Static functor type checker for binary functors with
+// float as the type of both parameters.
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity,
+    size_t arg_num = 0>
+struct check_binary_functor_types_for_specialization {
+  constexpr static inline bool check() {
+    bool current = false;
+    if constexpr (arity != 2)
+      return false;
+    if constexpr (arg_num == 0) {
+      using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
+    } else if constexpr (arg_num == 1) {
+      using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
+    }
+    return false;
+  }
+};
+
+// Bottom case: if we got this far, assume correct type matching except
+// when there are no arguments (arity == 0).
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    arity,
+    arity> {
+  constexpr static inline bool check() {
+    if constexpr (arity != 0)
+      return true;
+    return false;
+  }
+};
+
+template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    0,
+    0> {
+  constexpr static inline bool check() {
+    return false;
+  }
+};
+
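+// For example, a functor with signature float(float, float) recurses through
+// arg_num = 0 and 1 into the bottom case and check() returns true, while
+// float(float, BFloat16) checked against (FirstParamTy, SecondParamTy) =
+// (float, float) fails at arg_num == 1 and returns false. The heterogeneous
+// input dtypes are handled by the casting loads of the vectorized_templated
+// policy, not by the functor signature.
+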
+// The following is a list of type specializations for the vectorized_templated
+// elementwise kernel. It refers to the first and second runtime types of the
+// arguments of a binary functor.
+constexpr int number_of_binary_specializations = 4;
+const std::
+    array<std::array<c10::ScalarType, 2>, number_of_binary_specializations>
+        rt_binary_specializations = {
+            {{c10::CppTypeToScalarType<float>::value,
+              c10::CppTypeToScalarType<BFloat16>::value},
+             {c10::CppTypeToScalarType<BFloat16>::value,
+              c10::CppTypeToScalarType<float>::value},
+             {c10::CppTypeToScalarType<float>::value,
+              c10::CppTypeToScalarType<Half>::value},
+             {c10::CppTypeToScalarType<Half>::value,
+              c10::CppTypeToScalarType<float>::value}}};
+
+bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
+  if (iter.ninputs() != 2)
+    return false;
+  for (int i = 0; i < 4; i++)
+    if (iter.input_dtype(0) == rt_binary_specializations[i][0] &&
+        iter.input_dtype(1) == rt_binary_specializations[i][1])
+      return true;
+  return false;
+}
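+
+// Example: an iterator whose two input dtypes are (Float, BFloat16) or
+// (Half, Float) matches an entry of rt_binary_specializations and is
+// eligible for the specialized kernel; (Float, Float) or
+// (BFloat16, BFloat16) is not and uses the generic path instead.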
+} // namespace
+#endif
+
 template <typename func_t>
 void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -449,6 +697,115 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 
   if (contiguous) {
 #ifdef USE_ROCM
+    // Attempt to call the specialized vectorized elementwise kernel
+    // that enables interleaving.
+    if (check_binary_rt_types_for_specialization(iter) &&
+        memory::can_vectorize_up_to<func_t>(data) > 1) {
+      // Use constexpr checks to reduce the number of (empty) kernels generated
+      // for the unrolled templated elementwise path and to limit, at compile
+      // time, which functors are actually applied to the load and store.
+      using func_tuple = typename traits::ArgsTuple;
+      if constexpr (
+          std::is_same_v<float, arg0_t> && traits::arity == 2 &&
+          check_binary_functor_types_for_specialization<
+              func_tuple,
+              float,
+              float,
+              traits::arity,
+              /*current=*/0>::check()) {
+        // If we got here, we know we are in one of the specialized cases. We
+        // need to translate the runtime type to a statically known type. This
+        // effectively hoists to the host the switch over the runtime type that
+        // fetch_and_cast performs in the kernel. Loader, storer, and offset
+        // calculators are only needed for the remainder loop.
+        auto input_offset_calculator =
+            TrivialOffsetCalculator<traits::arity>();
+        auto output_offset_calculator = TrivialOffsetCalculator<1>();
+        auto loader = memory::LoadWithCast<traits::arity>(iter);
+        auto storer = memory::StoreWithCast<1>(iter);
+        if (iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<BFloat16>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              float,
+              BFloat16>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else if (
+            iter.input_dtype(0) == c10::CppTypeToScalarType<BFloat16>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              BFloat16,
+              float>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else if (
+            iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<Half>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              float,
+              Half>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else if (
+            iter.input_dtype(0) == c10::CppTypeToScalarType<Half>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              Half,
+              float>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else
+          TORCH_CHECK(false, "unreachable");
+        return;
+      }
+    }
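+    // Fall through: if the runtime dtype combination or the compile-time
+    // functor check does not match one of the specializations above, the
+    // generic dynamically casting kernel below is used instead.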
     at::detail::Array<ScalarType, ntensors> dtypes;
     auto inner_strides = iter.get_inner_strides();
     at::detail::Array<int, ntensors> strides;
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index cb14f275e21718..4bf9860239e0d1 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -66,6 +66,38 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   policy.store(results, idx);
 }
 
+#ifdef USE_ROCM
+template <
+    int thread_work_size = thread_work_size(),
+    typename func_t,
+    typename policy_t>
+__device__ inline void templated_elementwise_kernel_helper(
+    func_t f,
+    policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  int idx = blockIdx.x;
+
+  return_t results[thread_work_size];
+  args_t args[thread_work_size];
+
+  // load
+  policy.load(args, idx);
+
+// compute
+#pragma unroll
+  for (int i = 0; i < thread_work_size; i++) {
+    if (policy.check_inbounds(i)) {
+      results[i] = c10::guts::apply(f, args[i]);
+    }
+  }
+
+  // store
+  policy.store(results, idx);
+}
+#endif
 }} // namespace at::native
 
 #include <ATen/native/cuda/CUDALoops.cuh>
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index 1662d58789a72c..e81bc28771acd0 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -67,6 +67,26 @@ struct vectorized_load_helper {
   }
 };
 
+// Templated version of the vectorized load helper.
+// It can be used on heterogeneous input tensor element types.
+template <int arg_index>
+struct vectorized_templated_load_helper {
+  template <typename args_t, typename policy_t>
+  static __device__ void apply(policy_t& self, args_t* args, int idx) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    // `data` holds the data_ptr for tensors [output, input0, input1, ...], so
+    // we need a +1 offset to get the input.
+
+    // Delay the pointer arithmetic to the policy loader, where the actual
+    // type of the current argument is known.
+    char* ptr = (self.data[arg_index + 1]);
+    auto args_accessor = [&args] __device__(int thread_unroll_idx) -> arg_t& {
+      return std::get<arg_index>(args[thread_unroll_idx]);
+    };
+    self.template load_single_arg<arg_index>(args_accessor, ptr, idx);
+  }
+};
+
 template <int arg_index>
 struct unroll_load_helper {
   template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
@@ -183,9 +203,17 @@ namespace policies {
 
 // Assumption:
 // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
-template <typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
-struct unroll {
-
+template <
+    int num_threads,
+    int thread_work_size,
+    int block_work_size,
+    typename data_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    int num_outputs = 1>
+struct unroll_base {
   data_t data;
   int remaining;
   inp_calc_t input_offset_calculator;
@@ -193,11 +221,22 @@ struct unroll {
   loader_t loader;
   storer_t storer;
 
-  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
-    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
+  __device__ unroll_base(
+      data_t data,
+      int remaining,
+      inp_calc_t ic,
+      out_calc_t oc,
+      loader_t l,
+      storer_t s)
+      : data(data),
+        remaining(remaining),
+        input_offset_calculator(ic),
+        output_offset_calculator(oc),
+        loader(l),
+        storer(s) {}
 
   __device__ inline bool check_inbounds(int thread_work_elem) {
-    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
+    return ((int)(threadIdx.x + thread_work_elem * num_threads) < remaining);
   }
 
   template <typename args_t>
@@ -205,14 +244,14 @@ struct unroll {
     constexpr int arity = std::tuple_size<args_t>::value;
     int thread_idx = threadIdx.x;
 #pragma unroll
-    for (int i = 0; i < thread_work_size(); i++) {
-      if (thread_idx >= remaining) {
-        return;
+    for (int i = 0; i < thread_work_size; i++) {
+      if (thread_idx < remaining) {
+        int linear_idx = thread_idx + block_work_size * idx;
+        auto offset = input_offset_calculator.get(linear_idx);
+        detail::static_unroll<detail::unroll_load_helper, arity>::with_args(
+            *this, args, offset, loader, i, num_outputs);
+        thread_idx += num_threads;
       }
-      int linear_idx = thread_idx + block_work_size() * idx;
-      auto offset = input_offset_calculator.get(linear_idx);
-      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
-      thread_idx += num_threads();
     }
   }
 
@@ -220,18 +259,36 @@ struct unroll {
   __device__ inline void store(scalar_t *from, int idx) {
     int thread_idx = threadIdx.x;
 #pragma unroll
-    for (int i = 0; i < thread_work_size(); i++) {
-      if (thread_idx >= remaining) {
-        return;
+    for (int i = 0; i < thread_work_size; i++) {
+      if (thread_idx < remaining) {
+        int linear_idx = thread_idx + block_work_size * idx;
+        int offset = output_offset_calculator.get(linear_idx)[0];
+        storer.store(from[i], data[0], offset);
+        thread_idx += num_threads;
      }
-      int linear_idx = thread_idx + block_work_size() * idx;
-      int offset = output_offset_calculator.get(linear_idx)[0];
-      storer.store(from[i], data[0], offset);
-      thread_idx += num_threads();
     }
   }
 };
 
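+// unroll_base exposes num_threads, thread_work_size and block_work_size as
+// template parameters instead of reading the global launch constants, so the
+// same policy can also serve the remainder blocks of the templated kernel,
+// which run with the larger vectorized_templated_config configuration.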
+// Same as unroll_base, but uses the configuration from the current context
+// (num_threads(), thread_work_size(), block_work_size()).
+template <
+    typename data_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    int num_outputs = 1>
+using unroll = unroll_base<
+    num_threads(),
+    thread_work_size(),
+    block_work_size(),
+    data_t,
+    inp_calc_t,
+    out_calc_t,
+    loader_t,
+    storer_t,
+    num_outputs>;
+
 // Assumption:
 // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
 // Note:
@@ -289,6 +346,86 @@ struct vectorized {
   }
 };
 
+#ifdef USE_ROCM
+// This is similar to the vectorized policy above, but this one supports
+// heterogeneous input tensor element types as template parameters.
+// Its use should be limited to frequently used heterogeneous data types, as
+// each instantiation generates a separate kernel, leading to code bloat if it
+// is applied to all combinations supported in PyTorch.
+// Assumption:
+// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
+template <
+    int thread_work_size,
+    int num_threads,
+    int block_work_size,
+    int vec_size,
+    typename data_t,
+    typename CastToT,
+    typename... CastFromTs> // vec_size: number of scalars, can be 2, 4, or 8.
+struct vectorized_templated {
+  static_assert(
+      thread_work_size % vec_size == 0,
+      "The workload per thread must be a multiple of vec_size");
+  static constexpr int loop_size = thread_work_size / vec_size;
+
+  data_t data;
+
+  __device__ vectorized_templated(data_t data) : data(data) {}
+
+  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
+    return true;
+  }
+
+  template <typename args_t>
+  __device__ inline void load(args_t* args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    detail::static_unroll<detail::vectorized_templated_load_helper, arity>::
+        with_args(*this, args, idx);
+  }
+
+  template <int arg_index, typename accessor_t>
+  __device__ inline void load_single_arg(accessor_t to, char* ptr, int idx) {
+    // Extract the arg_index-th input tensor element type from the
+    // variadic template argument pack.
+    using CastFromT =
+        std::tuple_element_t<arg_index, std::tuple<CastFromTs...>>;
+    // Delayed pointer arithmetic from the caller: this is the place
+    // where we know the type of the argument.
+    CastFromT* block_ptr =
+        reinterpret_cast<CastFromT*>(ptr) + block_work_size * idx;
+    int thread_idx = threadIdx.x;
+#pragma unroll
+    for (int i = 0; i < loop_size; i++) {
+      int index = thread_idx + i * num_threads;
+      auto v = load_vector<vec_size>(block_ptr, index);
+#pragma unroll
+      for (int j = 0; j < vec_size; j++) {
+        to(vec_size * i + j) = c10::convert<CastToT>(v.val[j]);
+      }
+    }
+  }
+
+  // Assume for now that from (the temporary per-thread array) is of the same
+  // type as to (the destination tensor), which is the case for, e.g., an add
+  // functor on float(float, bfloat16) or float(float, float).
+  template <typename scalar_t>
+  __device__ inline void store(scalar_t* from, int idx) {
+    using vec_t = aligned_vector<scalar_t, vec_size>;
+    scalar_t* to = reinterpret_cast<scalar_t*>(data[0]) + block_work_size * idx;
+    vec_t* to_ = reinterpret_cast<vec_t*>(to);
+    int thread_idx = threadIdx.x;
+#pragma unroll
+    for (int i = 0; i < loop_size; i++) {
+      int index = thread_idx + i * num_threads;
+      vec_t v;
+      for (int j = 0; j < vec_size; j++) {
+        v.val[j] = from[vec_size * i + j];
+      }
+      to_[index] = v;
+    }
+  }
+};
+#endif
+
 template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
 struct multi_outputs_unroll {
   //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct