From 5eaa4666fbb0a74dfdc130cac678c9709674be5e Mon Sep 17 00:00:00 2001
From: Dmitry Nikolaev
Date: Tue, 4 Mar 2025 18:38:16 +0000
Subject: [PATCH] Revert "[release/2.5] aten::copy optimization. (#1862)"

This reverts commit 4ed5c3a7a47c98d960f6ad6d0fea48dd35885b8e.
---
 aten/src/ATen/native/cuda/Copy.cu | 172 ------------------------------
 1 file changed, 172 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 458bce7d896ef5..fad81d59d45c9c 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -137,176 +137,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   }
 }
 
-// This API is for detecting whether the permute parameter of a three-dimensional tensor
-// in the Copy operation from src to dst is from [0, 1, 2] to [0, 2, 1].
-bool is_permute_021(TensorIteratorBase &iter) {
-  const auto& input = iter.tensor(1);
-  const auto& output = iter.tensor(0);
-  bool is_permute = false;
-  if (input.dim() == 3) {
-    is_permute = true;
-    is_permute &= input.dim() == output.dim();
-
-    is_permute &= input.stride(0) == input.size(2) * input.stride(2);
-    is_permute &= input.stride(1) == 1;
-    is_permute &= input.stride(2) >= input.size(1);
-    is_permute &= output.is_contiguous();
-  }
-  return is_permute;
-}
-
-template <typename T, uint32_t BIG_TILE_SIZE_N, uint32_t BIG_TILE_SIZE_K, uint32_t M_SWIZZLE, uint32_t BLOCK_SIZE>
-__global__ void transpose_tile_big_kernel(const void* __restrict a, void* __restrict c, const int N, const int K, const int STRIDE_N_ELE)
-{
-    constexpr uint32_t elements_in_128b = 16 / sizeof(T);
-    union BLOCK_16B
-    {
-        T e[elements_in_128b];
-        __uint128_t ow;
-    };
-
-    constexpr int LDS_PAD = (4 / sizeof(T));
-    // Round up processing to next full tile
-    const uint32_t n_tiles = (N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N;
-    const uint32_t k_tiles = (K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K;
-    const uint32_t nk_tiles = n_tiles * k_tiles;
-    const uint32_t m_tiles = gridDim.x / nk_tiles;
-    const uint32_t m_tile_swizzle = blockIdx.x / nk_tiles / M_SWIZZLE * M_SWIZZLE;
-    /// do m_swizzle when there are enough m_tiles
-    const bool swizzle_m = m_tile_swizzle + M_SWIZZLE <= m_tiles;
-    const uint32_t current_m = swizzle_m ? m_tile_swizzle + blockIdx.x % M_SWIZZLE : blockIdx.x / nk_tiles;
-    const uint64_t stride_n = STRIDE_N_ELE * sizeof(T);
-    const uint64_t stride_k = N * sizeof(T);
-    const uint64_t out_stride_nk = N * K * sizeof(T);
-    const uint64_t in_stride_nk = N * STRIDE_N_ELE * sizeof(T);
-
-    const uint32_t current_nk = swizzle_m ? blockIdx.x / M_SWIZZLE % nk_tiles : blockIdx.x % nk_tiles;
-    const uint32_t ti = current_nk / k_tiles;
-    const uint32_t tj = current_nk % k_tiles;
-
-    __shared__ T smem[BIG_TILE_SIZE_N][BIG_TILE_SIZE_K + LDS_PAD];
-
-    // Detect partial tiles
-    const uint32_t current_n_size = (ti == (n_tiles - 1) && (N % BIG_TILE_SIZE_N) != 0) ? (N % BIG_TILE_SIZE_N) : BIG_TILE_SIZE_N;
-    const uint32_t current_k_size = (tj == (k_tiles - 1) && (K % BIG_TILE_SIZE_K) != 0) ?
-        (K % BIG_TILE_SIZE_K) : BIG_TILE_SIZE_K;
-    //use 128bit load&store whenever possible
-    if (current_n_size % 8 == 0 && current_k_size % 8 == 0)
-    {
-        // Copy full tile with large loads
-        constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
-        constexpr uint32_t ld_per_row = row_bytes / sizeof(__uint128_t);
-        constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
-        constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
-        // Make sure WG isn't too large
-        static_assert(vmem_per_thread >= 1);
-
-        const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
-        #pragma unroll
-        for (uint32_t t = 0; t < vmem_per_thread; t++)
-        {
-            uint32_t col = threadIdx.x % ld_per_row;
-            uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
-            uint64_t offset = (col * 8 < current_k_size && row < current_n_size) ?
-                row * stride_n + col * sizeof(__uint128_t) : 0;
-            const __uint128_t* pfa = (const __uint128_t*)(pat + offset);
-            BLOCK_16B d;
-            d.ow = *pfa;
-            #pragma unroll
-            for (uint32_t i = 0; i < elements_in_128b; i++)
-            {
-                smem[row][col * elements_in_128b + i] = d.e[i];
-            }
-        }
-        __syncthreads();
-        // Copy full tile with large loads
-        constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
-        constexpr uint32_t vmem_per_row_wr = row_bytes_wr / sizeof(__uint128_t);
-        constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
-        constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
-        // Make sure WG isn't too large
-        static_assert(wr_per_row >= 1);
-        const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
-        #pragma unroll
-        for (uint32_t t = 0; t < wr_per_row; t++)
-        {
-            uint32_t col = threadIdx.x % vmem_per_row_wr;
-            uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
-            if (col * 8 < current_n_size && row < current_k_size)
-            {
-                uint64_t offset = row * stride_k + col * sizeof(__uint128_t);
-                BLOCK_16B d;
-                // Transpose tile on read from LDS
-                #pragma unroll
-                for (uint32_t i = 0; i < elements_in_128b; i++)
-                {
-                    d.e[i] = smem[col * elements_in_128b + i][row];
-                }
-                __uint128_t* pfc = (__uint128_t*)(pc + offset);
-                *pfc = d.ow;
-            }
-        }
-    }
-    else
-    {
-        // Copy partial tiles with element accesses
-        constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
-        constexpr uint32_t ld_per_row = BIG_TILE_SIZE_K;
-        constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
-        constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
-        // Make sure WG isn't too large
-        static_assert(vmem_per_thread >= 1);
-
-        const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
-        #pragma unroll
-        for (uint32_t t = 0; t < vmem_per_thread; t++)
-        {
-            uint32_t col = threadIdx.x % ld_per_row;
-            uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
-            uint64_t offset = (col < current_k_size && row < current_n_size) ?
-                row * stride_n + col * 2 : 0;
-            const uint16_t* pfa = (const uint16_t*)(pat + offset);
-            smem[row][col] = *pfa;
-        }
-        __syncthreads();
-        // Copy full tile with large loads
-        constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
-        constexpr uint32_t vmem_per_row_wr = BIG_TILE_SIZE_N;
-        constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
-        constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
-        const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
-        #pragma unroll
-        for (uint32_t t = 0; t < wr_per_row; t++)
-        {
-            uint32_t col = threadIdx.x % vmem_per_row_wr;
-            uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
-            if (col < current_n_size && row < current_k_size)
-            {
-                uint64_t offset = row * stride_k + col * 2;
-                uint16_t* pfc = (uint16_t*)(pc + offset);
-                *pfc = smem[col][row];
-            }
-        }
-    }
-}
-
-void transpose_last2dim(TensorIteratorBase &iter) {
-  void* dst = iter.data_ptr(0);
-  void* src = iter.data_ptr(1);
-  const auto& input = iter.tensor(1);
-
-  int M = input.size(0);
-  int N = input.size(1);
-  int K = input.size(2);
-
-  auto stream = c10::cuda::getCurrentCUDAStream();
-  constexpr uint32_t BIG_TILE_SIZE_N = 64;
-  constexpr uint32_t BIG_TILE_SIZE_K = 64;
-  constexpr uint32_t M_SWIZZLE = 8;
-  const int grid_x = M * ((N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N) * ((K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K);
-  const dim3 grid_dim(grid_x, 1, 1);
-  const dim3 block_dim(256, 1, 1);
-  transpose_tile_big_kernel<uint16_t, BIG_TILE_SIZE_N, BIG_TILE_SIZE_K, M_SWIZZLE, 256>
-      <<<grid_dim, block_dim, 0, stream>>>(src, dst, K, N, input.stride(2));
-}
-
 // TODO: We probably can use the opaque type trick to avoid creating duplicate
 // kernels for equivalent bit lengths
 void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
@@ -323,8 +153,6 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
       gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    });
-  } else if (is_permute_021(iter) && (dtype == kBFloat16 || dtype == kHalf)) {
-    transpose_last2dim(iter);
   } else {
     AT_DISPATCH_V2(
         dtype, "copy_", AT_WRAP([&] {
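
For reference, a minimal sketch of the case the reverted fast path covered (illustrative only, not part of the diff; it uses the public ATen C++ API, and the shapes and dtype are made-up examples): a contiguous Half/BFloat16 destination filled from a [0, 2, 1]-permuted, otherwise contiguous 3-D source, which is what is_permute_021() detected and transpose_last2dim() handled. With this revert, such copies go through the generic AT_DISPATCH_V2 element-wise path in direct_copy_kernel_cuda instead.

    #include <ATen/ATen.h>

    int main() {
      // src is contiguous [M, N, K]; src.permute({0, 2, 1}) then has strides {N*K, 1, K},
      // which satisfies the stride checks is_permute_021() performed on the copy source.
      at::Tensor src = at::randn({8, 64, 128}, at::device(at::kCUDA).dtype(at::kHalf));
      at::Tensor dst = at::empty({8, 128, 64}, src.options());
      dst.copy_(src.permute({0, 2, 1}));  // previously hit transpose_last2dim; now the generic copy kernel
      return 0;
    }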