Revert "[release/2.5] aten::copy optimization. (#1862)"
This reverts commit 4ed5c3a.
dnikolaev-amd authored and pruthvistony committed Mar 4, 2025
1 parent 4b826b3 commit 5eaa466
Showing 1 changed file with 0 additions and 172 deletions.
aten/src/ATen/native/cuda/Copy.cu (0 additions, 172 deletions)
@@ -137,176 +137,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
}
}

// Detects whether the copy from src to dst corresponds to permuting a
// three-dimensional tensor from [0, 1, 2] to [0, 2, 1]: src has stride 1 along
// dim 1 (its last two dims are transposed, possibly with a padded row stride)
// and dst is contiguous.
bool is_permute_021(TensorIteratorBase &iter) {
  const auto& input = iter.tensor(1);
  const auto& output = iter.tensor(0);
  bool is_permute = false;
  if (input.dim() == 3) {
    is_permute = true;
    is_permute &= input.dim() == output.dim();

    is_permute &= input.stride(0) == input.size(2) * input.stride(2);
    is_permute &= input.stride(1) == 1;
    is_permute &= input.stride(2) >= input.size(1);
    is_permute &= output.is_contiguous();
  }
  return is_permute;
}
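// Editor's sketch (illustrative, not part of the original patch): the stride
// pattern accepted by is_permute_021 is what permute({0, 2, 1}) on a contiguous
// 3-D tensor produces. Shapes and the helper name below are hypothetical.
static void is_permute_021_example() {
  at::Tensor x = at::randn({4, 8, 16}, at::dtype(at::kHalf).device(at::kCUDA));
  // x is contiguous: sizes {4, 8, 16}, strides {128, 16, 1}.
  at::Tensor src = x.permute({0, 2, 1});
  // src: sizes {4, 16, 8}, strides {128, 1, 16}, so
  //   stride(0) == size(2) * stride(2)  (128 == 8 * 16),
  //   stride(1) == 1, and stride(2) >= size(1)  (16 >= 16).
  at::Tensor dst = at::empty({4, 16, 8}, x.options());  // contiguous destination
  dst.copy_(src);  // with the reverted patch applied, this copy would be expected
                   // to take the transpose fast path below
}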

template<typename T, int BLOCK_SIZE, int BIG_TILE_SIZE_N, int BIG_TILE_SIZE_K, int M_SWIZZLE>
__global__ void transpose_tile_big_kernel(const void* __restrict a, void* __restrict c, const int N, const int K, const int STRIDE_N_ELE)
{
    constexpr uint32_t elements_in_128b = 16 / sizeof(T);
    union BLOCK_16B
    {
        T e[elements_in_128b];
        __uint128_t ow;
    };

    constexpr int LDS_PAD = (4 / sizeof(T));
    // Round up processing to next full tile
    const uint32_t n_tiles = (N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N;
    const uint32_t k_tiles = (K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K;
    const uint32_t nk_tiles = n_tiles * k_tiles;
    const uint32_t m_tiles = gridDim.x / nk_tiles;
    const uint32_t m_tile_swizzle = blockIdx.x / nk_tiles / M_SWIZZLE * M_SWIZZLE;
    // Do m_swizzle only when there are enough m_tiles
    const bool swizzle_m = m_tile_swizzle + M_SWIZZLE <= m_tiles;
    const uint32_t current_m = swizzle_m ? m_tile_swizzle + blockIdx.x % M_SWIZZLE : blockIdx.x / nk_tiles;
    const uint64_t stride_n = STRIDE_N_ELE * sizeof(T);
    const uint64_t stride_k = N * sizeof(T);
    const uint64_t out_stride_nk = N * K * sizeof(T);
    const uint64_t in_stride_nk = N * STRIDE_N_ELE * sizeof(T);

    const uint32_t current_nk = swizzle_m ? blockIdx.x / M_SWIZZLE % nk_tiles : blockIdx.x % nk_tiles;
    const uint32_t ti = current_nk / k_tiles;
    const uint32_t tj = current_nk % k_tiles;

    __shared__ T smem[BIG_TILE_SIZE_N][BIG_TILE_SIZE_K + LDS_PAD];

    // Detect partial tiles
    const uint32_t current_n_size = (ti == (n_tiles - 1) && (N % BIG_TILE_SIZE_N) != 0) ? (N % BIG_TILE_SIZE_N) : BIG_TILE_SIZE_N;
    const uint32_t current_k_size = (tj == (k_tiles - 1) && (K % BIG_TILE_SIZE_K) != 0) ? (K % BIG_TILE_SIZE_K) : BIG_TILE_SIZE_K;
    // Use 128-bit loads & stores whenever possible
    if (current_n_size % 8 == 0 && current_k_size % 8 == 0)
    {
        // Load the full tile into LDS with 128-bit loads
        constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
        constexpr uint32_t ld_per_row = row_bytes / sizeof(__uint128_t);
        constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
        constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
        // Make sure WG isn't too large
        static_assert(vmem_per_thread >= 1);

        const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
        #pragma unroll
        for (uint32_t t = 0; t < vmem_per_thread; t++)
        {
            uint32_t col = threadIdx.x % ld_per_row;
            uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
            uint64_t offset = (col * 8 < current_k_size && row < current_n_size) ?
                row * stride_n + col * sizeof(__uint128_t) : 0;
            const __uint128_t* pfa = (const __uint128_t*)(pat + offset);
            BLOCK_16B d;
            d.ow = *pfa;
            #pragma unroll
            for (uint32_t i = 0; i < elements_in_128b; i++)
            {
                smem[row][col * elements_in_128b + i] = d.e[i];
            }
        }
        __syncthreads();
        // Write the tile back with 128-bit stores, transposing on the way out of LDS
        constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
        constexpr uint32_t vmem_per_row_wr = row_bytes_wr / sizeof(__uint128_t);
        constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
        constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
        // Make sure WG isn't too large
        static_assert(wr_per_row >= 1);
        const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
        #pragma unroll
        for (uint32_t t = 0; t < wr_per_row; t++)
        {
            uint32_t col = threadIdx.x % vmem_per_row_wr;
            uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
            if (col * 8 < current_n_size && row < current_k_size)
            {
                uint64_t offset = row * stride_k + col * sizeof(__uint128_t);
                BLOCK_16B d;
                // Transpose tile on read from LDS
                #pragma unroll
                for (uint32_t i = 0; i < elements_in_128b; i++)
                {
                    d.e[i] = smem[col * elements_in_128b + i][row];
                }
                __uint128_t* pfc = (__uint128_t*)(pc + offset);
                *pfc = d.ow;
            }
        }
    }
    else
    {
        // Partial tile: load into LDS with element accesses
        // (this path assumes sizeof(T) == 2; the kernel is only instantiated for 16-bit types)
        constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
        constexpr uint32_t ld_per_row = BIG_TILE_SIZE_K;
        constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
        constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
        // Make sure WG isn't too large
        static_assert(vmem_per_thread >= 1);

        const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
        #pragma unroll
        for (uint32_t t = 0; t < vmem_per_thread; t++)
        {
            uint32_t col = threadIdx.x % ld_per_row;
            uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
            uint64_t offset = (col < current_k_size && row < current_n_size) ? row * stride_n + col * 2 : 0;
            const uint16_t* pfa = (const uint16_t*)(pat + offset);
            smem[row][col] = *pfa;
        }
        __syncthreads();
        // Write the tile back with element accesses, transposing on the way out of LDS
        constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
        constexpr uint32_t vmem_per_row_wr = BIG_TILE_SIZE_N;
        constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
        constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
        const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
        #pragma unroll
        for (uint32_t t = 0; t < wr_per_row; t++)
        {
            uint32_t col = threadIdx.x % vmem_per_row_wr;
            uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
            if (col < current_n_size && row < current_k_size)
            {
                uint64_t offset = row * stride_k + col * 2;
                uint16_t* pfc = (uint16_t*)(pc + offset);
                *pfc = smem[col][row];
            }
        }
    }
}
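// Editor's worked example (illustrative, assumed sizes, not from the patch): how
// the block-index bookkeeping above decomposes for hypothetical kernel arguments
// N = 100, K = 200, with 16 M-slices, 64x64 tiles, and M_SWIZZLE = 8.
namespace transpose_tile_index_example {
constexpr uint32_t n_tiles  = (100 + 64 - 1) / 64;         // 2
constexpr uint32_t k_tiles  = (200 + 64 - 1) / 64;         // 4
constexpr uint32_t nk_tiles = n_tiles * k_tiles;           // 8
constexpr uint32_t grid_x   = 16 * nk_tiles;               // 128 blocks
constexpr uint32_t m_tiles  = grid_x / nk_tiles;           // 16
// For blockIdx.x == 21, swizzle_m is true (0 + 8 <= 16), so:
constexpr uint32_t block      = 21;
constexpr uint32_t m_swizzle  = block / nk_tiles / 8 * 8;  // 0
constexpr uint32_t current_m  = m_swizzle + block % 8;     // 5
constexpr uint32_t current_nk = block / 8 % nk_tiles;      // 2
constexpr uint32_t ti = current_nk / k_tiles;              // 0
constexpr uint32_t tj = current_nk % k_tiles;              // 2
static_assert(current_m == 5 && ti == 0 && tj == 2);
// The partial row of tiles (ti == 1) has current_n_size = 100 % 64 = 36, which is
// not a multiple of 8, so those blocks fall back to the element-wise path.
} // namespace transpose_tile_index_example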

void transpose_last2dim(TensorIteratorBase &iter) {
  void* dst = iter.data_ptr(0);
  void* src = iter.data_ptr(1);
  const auto& input = iter.tensor(1);

  int M = input.size(0);
  int N = input.size(1);
  int K = input.size(2);

  auto stream = c10::cuda::getCurrentCUDAStream();
  constexpr uint32_t BIG_TILE_SIZE_N = 64;
  constexpr uint32_t BIG_TILE_SIZE_K = 64;
  constexpr uint32_t M_SWIZZLE = 8;
  const int grid_x = M * ((N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N) * ((K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K);
  const dim3 grid_dim(grid_x, 1, 1);
  const dim3 block_dim(256, 1, 1);
  transpose_tile_big_kernel<uint16_t, 256, BIG_TILE_SIZE_N, BIG_TILE_SIZE_K, M_SWIZZLE><<<grid_dim, block_dim, 0, stream>>>(src, dst, K, N, input.stride(2));
}
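// Editor's sketch (hypothetical names, not from the patch): the copy the launcher
// performs, written as a naive elementwise kernel for reference. The tiled kernel
// above computes the same thing but stages 64x64 tiles through LDS so that both
// the global loads and the global stores stay coalesced.
__global__ void naive_permute021_copy(const uint16_t* __restrict src, uint16_t* __restrict dst,
                                      int M, int N, int K, int64_t src_stride_k) {
  // dst is contiguous [M, N, K]; src is the [0, 2, 1] permuted view with element
  // strides {K * src_stride_k, 1, src_stride_k}, as accepted by is_permute_021.
  // Assumes a launch with at least M * N * K threads in total.
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx >= (int64_t)M * N * K) return;
  int k = idx % K;
  int n = (idx / K) % N;
  int m = idx / ((int64_t)N * K);
  dst[idx] = src[(int64_t)m * K * src_stride_k + (int64_t)k * src_stride_k + n];
}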

// TODO: We probably can use the opaque type trick to avoid creating duplicate
// kernels for equivalent bit lengths
void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
@@ -323,8 +153,6 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
    AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
      gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    });
  } else if (is_permute_021(iter) && (dtype == kBFloat16 || dtype == kHalf)) {
    transpose_last2dim(iter);
  } else {
    AT_DISPATCH_V2(
        dtype, "copy_", AT_WRAP([&] {
