Skip to content

Commit

Permalink
update
Browse the repository at this point in the history
  • Loading branch information
guocuimi committed Feb 12, 2025
1 parent ced454e commit 027c8dd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 23 deletions.
43 changes: 24 additions & 19 deletions src/kernels/attention/mla_kernel_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,30 +64,34 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt;
using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

const int m_block_idx = blockIdx.x;
const int batch_idx = blockIdx.y;
const int tidx = threadIdx.x;

MLATile<Params> tile(params);

// ProblemShape
// Q/O: (q_packed_len, HEAD_DIM)
// KV: (kv_len, HEAD_DIM)
// Q/K_ROPE: (q_packed_len, ROPE_HEAD_DIM)
auto [Q, Q_ROPE, O] = tile.template get_qo_tile<DType>(blockIdx.y);
auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(blockIdx.y);
auto [Q, Q_ROPE, O] = tile.template get_qo_tile<DType>(batch_idx);
auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(batch_idx);

if (blockIdx.x * kBlockM >= size<0>(Q)) {
if (m_block_idx * kBlockM >= size<0>(Q)) {
// m out of bound, return
return;
}

// Gmem
// (BLK_M, BLK_K, STAGES)
Tensor gQ = local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(blockIdx.x, _));
Tensor gO = local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(blockIdx.x, _));
Tensor gQ = local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
Tensor gO = local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
// (BLK_N, BLK_K, n, STAGES)
Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _));

// (BLK_M, ROPE_HEAD_DIM)
Tensor gQ_rope = local_tile(
Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(blockIdx.x, _0{}));
Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
// (BLK_N, ROPE_HEAD_DIM, n)
Tensor gK_rope =
local_tile(K_ROPE, Shape<_BLK_N, _ROPE_HEAD_DIM>{}, make_coord(_, _0{}));
Expand Down Expand Up @@ -117,8 +121,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// g2s tiled copy for qkv
GmemTiledCopyQ gmem_tiled_copy_Q;
GmemTiledCopyKV gmem_tiled_copy_KV;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(threadIdx.x);
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(threadIdx.x);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);

auto produce_q = [&](int stage) {
// gQ/sQ: (BLK_M, BLK_K, STAGES)
Expand Down Expand Up @@ -153,7 +157,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
};

TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(threadIdx.x);
auto thr_mma = tiled_mma.get_slice(tidx);
// GEMM-I: S = Q@K.T
// sQ/sK: (BLK_M, BLK_K, STAGES)
auto tSrQ = partition_fragment_A(
Expand All @@ -168,7 +172,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

// s2r tiled copy for qkv
SmemTiledCopyQ smem_tiled_copy_Q;
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(threadIdx.x);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
// (CPY, CPY_M, CPY_K, STAGES)
auto tCsQ = smem_thr_copy_Q.partition_S(sQ);
// (CPY, CPY_M, _2)
Expand All @@ -180,7 +184,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);

SmemTiledCopyK smem_tiled_copy_K;
auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(threadIdx.x);
auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
auto tCsK = smem_thr_copy_K.partition_S(sK);
// (CPY, CPY_N, _2)
Expand Down Expand Up @@ -241,7 +245,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto tOrVt = partition_fragment_B(thr_mma, sVt(_, _, _0{}), _, _2{});

SmemTiledCopyVt smem_tiled_copy_Vt;
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(threadIdx.x);
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
auto tCsVt = smem_thr_copy_Vt.partition_S(sVt);
// (CPY, CPY_N, _2)
Expand All @@ -257,6 +261,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto tCsVt_s = tCsVt(_, _, _, s);
// tCsVt_s: (CPY, CPY_N, CPY_K) => tCrVt: (CPY, CPY_N, _2)
cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));

auto tOrS_k = make_tensor_like<DType>(tOrS(_, _, _0{}));
CUTE_UNROLL
for (int k = 0; k < size<2>(tOrS); ++k) {
if (k != size<2>(tOrS) - 1) {
Expand All @@ -265,7 +271,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
tCsVt_s(_, _, next_k),
tCrVt(_, _, (next_k & 1)));
}
cute::gemm(tiled_mma, tOrS(_, _, k), tOrVt(_, _, (k & 1)), tOrO_s);
// cast scores from Accumulator to Element
fast_cast(tOrS(_, _, k), tOrS_k);
cute::gemm(tiled_mma, tOrS_k, tOrVt(_, _, (k & 1)), tOrO_s);
}
};

Expand All @@ -276,7 +284,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// (BLK_M, BLK_K, STAGES)
auto sO = make_tensor(sQ.data(), SmemLayoutO{});
SmemTiledCopyO smem_tiled_copy_O;
auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(threadIdx.x);
auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx);
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
auto tOrO_s = tOrO(_, _, _, s);
Expand All @@ -296,7 +304,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// 2. copy output from smem to gmem
{
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(threadIdx.x);
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx);

auto tCsO = gmem_thr_copy_O.partition_S(sO);
auto tCgO = gmem_thr_copy_O.partition_D(gO);
Expand Down Expand Up @@ -360,11 +368,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
softmax.rescale(tSrS_mn, tOrO_mn);

// 3> O = softmax(S)*V
// cast scores from Accumulator to Element
auto tSrS_ = make_tensor_like<DType>(tSrS);
fast_cast(tSrS, tSrS_);
// convert layout from gemm-I C to gemm-II A
auto tOrS = make_tensor(tSrS_.data(), Layout::to_mma_a(tSrS_.layout()));
auto tOrS = make_tensor(tSrS.data(), Layout::to_mma_a(tSrS.layout()));
const auto next_ni = ni + 1;
if (next_ni != n_block_max) {
produce_k_rope(next_ni);
Expand Down
8 changes: 4 additions & 4 deletions src/kernels/attention/mla_sm80_bench.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@ using namespace llm;
} else if (HEAD_DIM_V <= 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 128; \
constexpr static int BLK_K = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 128; \
constexpr static int BLK_K = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 384) { \
constexpr static int HEAD_DIM_NAME = 384; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 128; \
constexpr static int BLK_K = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 512) { \
constexpr static int HEAD_DIM_NAME = 512; \
constexpr static int BLK_N = 32; \
constexpr static int BLK_K = 128; \
constexpr static int BLK_K = 64; \
return __VA_ARGS__(); \
} else { \
assert(false); \
Expand Down

0 comments on commit 027c8dd

Please sign in to comment.