From fa661c9616621bc3da10d0b027cd8c6c042bb6af Mon Sep 17 00:00:00 2001
From: Carlo Bertolli
Date: Mon, 3 Mar 2025 10:53:16 -0600
Subject: [PATCH] Support heterogeneous tensor types for the vectorized
 elementwise kernel.

This patch includes the changes from
https://github.com/pytorch/pytorch/pull/147527 and extends them to support
multiple input tensor types.
---
 aten/src/ATen/native/cuda/CUDALoops.cuh    | 206 +++++++++++++++++++++
 aten/src/ATen/native/cuda/Loops.cuh        |  28 +++
 aten/src/ATen/native/cuda/MemoryAccess.cuh | 128 +++++++++++--
 3 files changed, 344 insertions(+), 18 deletions(-)

diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index bf98cf46277c71..de3d8ecdcda0d3 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -50,6 +50,16 @@
 #define ASSERT_HOST_DEVICE_LAMBDA(type)
 #endif
 
+namespace vectorized_templated_config {
+constexpr int num_threads() {
+  return 256;
+}
+
+constexpr int thread_work_size() { return 64; }
+
+constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+} // namespace vectorized_templated_config
+
 namespace at {
 namespace native {
 
@@ -147,6 +157,92 @@ static inline void launch_vectorized_kernel(
   }
 }
 
+#ifdef USE_ROCM
+template <
+    int vec_size,
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    typename OutputType,
+    typename... InputTypes>
+C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
+__global__ void vectorized_templated_elementwise_kernel(
+    int N,
+    func_t f,
+    array_t data,
+    inp_calc_t inp_calc,
+    out_calc_t out_calc,
+    loader_t loader,
+    storer_t storer) {
+  using traits = function_traits<func_t>;
+  int remaining =
+      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  if (remaining < vectorized_templated_config::block_work_size()) {
+    // This block handles the remainder: just do a naive unrolled loop.
+    auto policy = memory::policies::unroll_base<
+        vectorized_templated_config::thread_work_size(),
+        vectorized_templated_config::num_threads(),
+        vectorized_templated_config::block_work_size(),
+        array_t,
+        inp_calc_t,
+        out_calc_t,
+        loader_t,
+        storer_t>(data, remaining, inp_calc, out_calc, loader, storer);
+    templated_elementwise_kernel_helper<
+        vectorized_templated_config::thread_work_size()>(f, policy);
+  } else {
+    // This block has a full `block_work_size` of data to handle: use
+    // vectorized memory access.
+    templated_elementwise_kernel_helper<
+        vectorized_templated_config::thread_work_size()>(
+        f,
+        memory::policies::vectorized_templated<
+            vec_size,
+            vectorized_templated_config::thread_work_size(),
+            vectorized_templated_config::num_threads(),
+            vectorized_templated_config::block_work_size(),
+            array_t,
+            OutputType,
+            InputTypes...>(data));
+  }
+}
+
+// This function assumes a trivially contiguous 1D case and supports template
+// specialization to avoid dynamic casting.
+// The input vectorization size is based on runtime information, i.e. the
+// actual data types of the input and output tensors, and cannot be determined
+// from the functor type as in the regular, non-templated vectorized kernels.
+// The caller is in charge of selecting the correct input vectorization length.
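+//
+// As an illustrative sketch only (the concrete dispatch lives in
+// gpu_kernel_impl below), a caller that has established the runtime type
+// combination float(float, BFloat16) for a binary functor would instantiate
+// the launcher along the lines of:
+//
+//   launch_vectorized_templated_kernel<
+//       func_t,
+//       at::detail::Array<char*, 3>,
+//       decltype(ic), decltype(oc), decltype(l), decltype(s),
+//       float, float, BFloat16>(numel, f, data, ic, oc, l, s);
+//
+// so that the kernel reads BFloat16 elements directly and converts them to
+// float without a per-element runtime switch.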
+template <
+    typename func_t,
+    typename array_t,
+    typename inp_calc_t,
+    typename out_calc_t,
+    typename loader_t,
+    typename storer_t,
+    typename OutputType,
+    typename... InputTypes>
+static inline void launch_vectorized_templated_kernel(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    inp_calc_t ic,
+    out_calc_t oc,
+    loader_t l,
+    storer_t s) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  using traits = function_traits<func_t>;
+  int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) /
+      vectorized_templated_config::block_work_size();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int vec_size = memory::can_vectorize_up_to<OutputType, InputTypes...>(data);
+  switch (vec_size) {
+    case 8:
+      vectorized_templated_elementwise_kernel<
+          8,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 4:
+      vectorized_templated_elementwise_kernel<
+          4,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    case 2:
+      vectorized_templated_elementwise_kernel<
+          2,
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          OutputType,
+          InputTypes...>
+          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
+              N, f, data, ic, oc, l, s);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      break;
+    default:
+      // Vector size 1 is not handled by the vectorized_templated kernel.
+      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
+  }
+}
+#endif
+
 template <
     typename func_t,
     typename array_t,
@@ -425,6 +521,68 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 #endif
 }
 
+#ifdef USE_ROCM
+namespace {
+// Static functor type checker for binary functors with
+// float as the type of both parameters.
+template <typename TupleLike, size_t arity, size_t arg_num = 0>
+struct check_binary_functor_types_for_specialization {
+  constexpr static inline bool check() {
+    if constexpr (arity != 2)
+      return false;
+    if constexpr (arg_num == 0) {
+      using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<float, SelectedType>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            arity,
+            arg_num + 1>::check();
+    } else if constexpr (arg_num == 1) {
+      using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
+      if constexpr (std::is_same_v<float, SelectedType2>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            arity,
+            arg_num + 1>::check();
+    }
+    return false;
+  }
+};
+
+// Base case: if we got this far, assume correct type matching except
+// when there are no arguments (arity == 0).
+template <typename TupleLike, size_t arity>
+struct check_binary_functor_types_for_specialization<TupleLike, arity, arity> {
+  constexpr static inline bool check() {
+    if constexpr (arity != 0)
+      return true;
+    return false;
+  }
+};
+
+template <typename TupleLike>
+struct check_binary_functor_types_for_specialization<TupleLike, 0, 0> {
+  constexpr static inline bool check() {
+    return false;
+  }
+};
+
+// The following is the list of type specializations for the
+// vectorized_templated elementwise kernel. Each entry holds the first and
+// second runtime types of the arguments of a binary functor.
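+// For example, a TensorIterator whose inputs have runtime dtypes
+// (Float, BFloat16) matches the first entry below and, in gpu_kernel_impl,
+// is dispatched to the launch_vectorized_templated_kernel instantiation whose
+// trailing template arguments are <float, float, BFloat16>.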
+constexpr int number_of_binary_specializations = 4;
+const std::array<std::array<c10::ScalarType, 2>, number_of_binary_specializations>
+    rt_binary_specializations = {
+        {{c10::CppTypeToScalarType<float>::value,
+          c10::CppTypeToScalarType<BFloat16>::value},
+         {c10::CppTypeToScalarType<BFloat16>::value,
+          c10::CppTypeToScalarType<float>::value},
+         {c10::CppTypeToScalarType<float>::value,
+          c10::CppTypeToScalarType<Half>::value},
+         {c10::CppTypeToScalarType<Half>::value,
+          c10::CppTypeToScalarType<float>::value}}};
+
+bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
+  if (iter.ninputs() != 2)
+    return false;
+  for (int i = 0; i < number_of_binary_specializations; i++)
+    if (iter.input_dtype(0) == rt_binary_specializations[i][0] &&
+        iter.input_dtype(1) == rt_binary_specializations[i][1])
+      return true;
+  return false;
+}
+} // anonymous namespace
+#endif
+
 template <typename func_t>
 void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -449,6 +607,54 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 
   if (contiguous) {
 #ifdef USE_ROCM
+    // Attempt to call the specialized vectorized elementwise kernel that
+    // enables interleaving.
+    if (check_binary_rt_types_for_specialization(iter) &&
+        memory::can_vectorize_up_to<func_t>(data) > 1) {
+      // constexpr to reduce the number of (empty) kernels generated for the
+      // unrolled templated elementwise path and to limit, at compile time,
+      // which functors are actually applied to the load and store.
+      using func_tuple = typename traits::ArgsTuple;
+      if constexpr (
+          std::is_same_v<float, arg0_t> && traits::arity == 2 &&
+          check_binary_functor_types_for_specialization<
+              func_tuple,
+              traits::arity>::check()) {
+        // If we got here, we know we are in one of the specialized cases. We
+        // need to translate the runtime type to a statically known type; this
+        // effectively hoists to the host the switch over the runtime type that
+        // fetch_and_cast would otherwise perform in the kernel.
+        // Loader, storer, and offset calculators are only needed for the
+        // remainder loop.
+        auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
+        auto output_offset_calculator = TrivialOffsetCalculator<1>();
+        auto loader = memory::LoadWithCast<traits::arity>(iter);
+        auto storer = memory::StoreWithCast<1>(iter);
+        if (iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<BFloat16>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              float,
+              BFloat16>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else if (
+            iter.input_dtype(0) == c10::CppTypeToScalarType<BFloat16>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              BFloat16,
+              float>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else if (
+            iter.input_dtype(0) == c10::CppTypeToScalarType<float>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<Half>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              float,
+              Half>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else if (
+            iter.input_dtype(0) == c10::CppTypeToScalarType<Half>::value &&
+            iter.input_dtype(1) == c10::CppTypeToScalarType<float>::value)
+          launch_vectorized_templated_kernel<
+              func_t,
+              at::detail::Array<char*, ntensors>,
+              decltype(input_offset_calculator),
+              decltype(output_offset_calculator),
+              decltype(loader),
+              decltype(storer),
+              float,
+              Half,
+              float>(
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
+        else
+          TORCH_CHECK(false, "unreachable");
+        return;
+      }
+    }
     at::detail::Array<ScalarType, ntensors> dtypes;
     auto inner_strides = iter.get_inner_strides();
     at::detail::Array<int, ntensors> strides;
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index cb14f275e21718..cccf9b9397e3b7 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -66,6 +66,34 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   policy.store(results, idx);
 }
 
+#ifdef USE_ROCM
+template <int thread_work_size, typename func_t, typename policy_t>
+__device__ inline void templated_elementwise_kernel_helper(
+    func_t f,
+    policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  int idx = blockIdx.x;
+
+  return_t results[thread_work_size];
+  args_t args[thread_work_size];
+
+  // load
+  policy.load(args, idx);
+
+  // compute
+  #pragma unroll
+  for (int i = 0; i < thread_work_size; i++) {
+    if (policy.check_inbounds(i)) {
+      results[i] = c10::guts::apply(f, args[i]);
+    }
+  }
+
+  // store
+  policy.store(results, idx);
+}
+#endif
+
 }} // namespace at::native
 
 #include <ATen/native/cuda/CUDALoops.cuh>
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index 1662d58789a72c..463dc9ad372edb 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -67,6 +67,25 @@ struct vectorized_load_helper {
   }
 };
 
+// Templated version of the vectorized load helper.
+// It can be used with heterogeneous input tensor element types.
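+// For a float(float, BFloat16) functor, for instance, arg_index 1 resolves
+// arg_t to float (the functor's argument type), while the policy's
+// load_single_arg is expected to read BFloat16 elements from memory and
+// convert them before they are handed to the functor.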
+template <int arg_index>
+struct vectorized_templated_load_helper {
+  template <typename args_t, typename policy_t>
+  static __device__ void apply(policy_t& self, args_t* args, int idx) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    // `data` holds the data_ptr for tensors [output, input0, input1, ...], so
+    // we need a +1 offset to get the input.
+
+    // Delay the pointer arithmetic to the policy loader, where the actual
+    // type of the current argument is known.
+    char* ptr = (self.data[arg_index + 1]);
+    auto args_accessor = [&args] __device__(int thread_unroll_idx) -> arg_t& {
+      return std::get<arg_index>(args[thread_unroll_idx]);
+    };
+    self.template load_single_arg<arg_index>(args_accessor, ptr, idx);
+  }
+};
+
 template <int arg_index>
 struct unroll_load_helper {
   template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
@@ -183,8 +202,8 @@ namespace policies {
 // Assumption:
 // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
 
-template <typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
-struct unroll {
+template <int thread_work_size, int num_threads, int block_work_size, typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
+struct unroll_base {
 
   data_t data;
   int remaining;
   inp_calc_t input_offset_calculator;
   out_calc_t output_offset_calculator;
   loader_t loader;
   storer_t storer;
 
-  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
+  __device__ unroll_base(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
     data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
 
   __device__ inline bool check_inbounds(int thread_work_elem) {
-    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
+    return ((int)(threadIdx.x + thread_work_elem*num_threads) < remaining);
   }
 
   template <typename args_t>
   __device__ inline void load(args_t *args, int idx) {
     constexpr int arity = std::tuple_size<args_t>::value;
     int thread_idx = threadIdx.x;
     #pragma unroll
-    for (int i = 0; i < thread_work_size(); i++) {
-      if (thread_idx >= remaining) {
-        return;
+    for (int i = 0; i < thread_work_size; i++) {
+      if (thread_idx < remaining) {
+        int linear_idx = thread_idx + block_work_size * idx;
+        auto offset = input_offset_calculator.get(linear_idx);
+        detail::static_unroll<unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
+        thread_idx += num_threads;
       }
-      int linear_idx = thread_idx + block_work_size() * idx;
-      auto offset = input_offset_calculator.get(linear_idx);
-      detail::static_unroll<unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
-      thread_idx += num_threads();
     }
   }
 
   template <typename scalar_t>
@@ -220,18 +238,21 @@ struct unroll {
   __device__ inline void store(scalar_t *from, int idx) {
     int thread_idx = threadIdx.x;
     #pragma unroll
-    for (int i = 0; i < thread_work_size(); i++) {
-      if (thread_idx >= remaining) {
-        return;
+    for (int i = 0; i < thread_work_size; i++) {
+      if (thread_idx < remaining) {
+        int linear_idx = thread_idx + block_work_size * idx;
+        int offset = output_offset_calculator.get(linear_idx)[0];
+        storer.store(from[i], data[0], offset);
+        thread_idx += num_threads;
       }
-      int linear_idx = thread_idx + block_work_size() * idx;
-      int offset = output_offset_calculator.get(linear_idx)[0];
-      storer.store(from[i], data[0], offset);
-      thread_idx += num_threads();
     }
   }
 };
 
+// Same as unroll_base, but uses the launch configuration of the current
+// compilation context (thread_work_size(), num_threads(), block_work_size()).
+template <typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
+using unroll = unroll_base<thread_work_size(), num_threads(), block_work_size(), data_t, inp_calc_t, out_calc_t, loader_t, storer_t, num_outputs>;
+
 // Assumption:
 // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
 // Note:
@@ -289,6 +310,77 @@ struct vectorized {
   }
 };
 
+#ifdef USE_ROCM
+// This is similar to the vectorized policy above, but this one supports
+// heterogeneous input tensor types as templated parameters.
+// Its use should be limited to frequently used heterogeneous data types,
+// as each instantiation will generate a separate kernel, leading to code
+// bloat if it is applied to all combinations supported in PyTorch.
+// Assumption:
+// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
+template <
+    int vec_size, // number of scalars per vectorized access (the launcher uses 2, 4, or 8)
+    int thread_work_size,
+    int num_threads,
+    int block_work_size,
+    typename data_t,
+    typename CastToT,
+    typename... CastFromTs>
+struct vectorized_templated {
+  static_assert(
+      thread_work_size % vec_size == 0,
+      "The workload per thread must be a multiple of vec_size");
+  static constexpr int loop_size = thread_work_size / vec_size;
+
+  data_t data;
+
+  __device__ vectorized_templated(data_t data) : data(data) {}
+
+  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
+    return true;
+  }
+
+  template <typename args_t>
+  __device__ inline void load(args_t* args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    detail::static_unroll<vectorized_templated_load_helper, arity>::with_args(
+        *this, args, idx);
+  }
+
+  template <int arg_index, typename accessor_t>
+  __device__ inline void load_single_arg(accessor_t to, char* ptr, int idx) {
+    // Extract the arg_index-th input tensor element type from the
+    // variadic template arguments.
+    using CastFromT = std::tuple_element_t<arg_index, std::tuple<CastFromTs...>>;
+    // Delayed pointer arithmetic from the caller: this is the place
+    // where we know the type of the argument.
+    CastFromT* block_ptr =
+        reinterpret_cast<CastFromT*>(ptr) + block_work_size * idx;
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < loop_size; i++) {
+      int index = thread_idx + i * num_threads;
+      auto v = load_vector<vec_size>(block_ptr, index);
+      #pragma unroll
+      for (int j = 0; j < vec_size; j++) {
+        to(vec_size * i + j) = c10::convert<CastToT>(v.val[j]);
+      }
+    }
+  }
+
+  // Assume for now that `from` (the temporary per-thread array) is of the same
+  // type as `to` (the destination tensor), which is the case for
+  // float(float, bfloat16) and for the add functor on float(float, float).
+  template <typename scalar_t>
+  __device__ inline void store(scalar_t* from, int idx) {
+    using vec_t = aligned_vector<scalar_t, vec_size>;
+    scalar_t* to = reinterpret_cast<scalar_t*>(data[0]) + block_work_size * idx;
+    vec_t* to_ = reinterpret_cast<vec_t*>(to);
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < loop_size; i++) {
+      int index = thread_idx + i * num_threads;
+      vec_t v;
+      for (int j = 0; j < vec_size; j++) {
+        v.val[j] = from[vec_size * i + j];
+      }
+      to_[index] = v;
+    }
+  }
+};
+#endif
+
 template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
 struct multi_outputs_unroll {
   //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct