
kernel: added q and kv oob handling for MLA kernel
guocuimi committed Feb 26, 2025
1 parent e855f1f commit 14beb48
Showing 5 changed files with 101 additions and 51 deletions.
2 changes: 2 additions & 0 deletions docker/Dockerfile.devel
@@ -54,6 +54,8 @@ ARG CUDA_VERSION=12.1
COPY ./common/install_cuda.sh install_cuda.sh
RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh
ENV DESIRED_CUDA=${CUDA_VERSION}
ENV CUDA_HOME=/usr/local/cuda
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
RUN nvcc --version

14 changes: 7 additions & 7 deletions src/kernels/attention/mask.h
@@ -5,7 +5,7 @@
namespace llm {
using namespace cute;

template <int BLK_M, int BLK_N, int ROWS_PER_THR, bool ALIBI, bool LOCAL>
template <int ROWS_PER_THR, bool ALIBI, bool LOCAL>
struct Mask {
// Fragment type for alibi slopes
using FragmentT = decltype(make_tensor<float>(Int<ROWS_PER_THR>{}));
@@ -31,15 +31,15 @@ struct Mask {
// cS_mn: ((2, MMA_M), (2, MMA_N))
template <typename IdentityS>
CUTE_HOST_DEVICE void init_alibi(IdentityS& cS_mn,
int m_block_idx,
int m_base_idx,
int kv_head_idx,
float sm_scale,
const float* alibi_slops_ptr) {
// copy alibi slopes to registers
CUTE_UNROLL
for (int i = 0; i < size<0>(cS_mn); ++i) {
const auto [m, n] = cS_mn(i, _0{});
const int q_packed_idx = m_block_idx * BLK_M + m;
const int q_packed_idx = m_base_idx + m;
const int offset = q_packed_idx % group_size_;
const int head_idx = kv_head_idx * group_size_ + offset;
alibi_slopes_(i) = alibi_slops_ptr[head_idx] / sm_scale;
@@ -50,16 +50,16 @@
template <bool OOB_MASK = true, typename FragmentS, typename IdentityS>
CUTE_HOST_DEVICE void apply(FragmentS& rS_mn,
IdentityS& cS_mn,
int m_block_idx,
int n_block_idx) const {
int m_base_idx,
int n_base_idx) const {
CUTE_UNROLL
for (int i = 0; i < size<0>(rS_mn); ++i) {
const auto alibi_slope = ALIBI ? alibi_slopes_(i) : 0.0f;
CUTE_UNROLL
for (int j = 0; j < size<1>(rS_mn); ++j) {
auto [m, n] = cS_mn(i, j);
const int q_packed_idx = m_block_idx * BLK_M + m;
const int kv_idx = n_block_idx * BLK_N + n;
const int q_packed_idx = m_base_idx + m;
const int kv_idx = n_base_idx + n;

const int q_idx = q_packed_idx / group_size_ + diagonal_offset_;

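The mask.h change drops the BLK_M/BLK_N template parameters: callers now pass element base indices (m_block_idx * BLK_M, n_block_idx * BLK_N) instead of block indices. A minimal standalone sketch of that indexing, assuming the usual causal diagonal offset of kv_len - q_len (the diagonal_offset_ member is not shown in this hunk); all names below are illustrative, not the kernel's:

#include <cstdio>

// Returns true if the score at tile coordinate (m, n) must be masked out,
// either because it is out of bounds or because it lies above the causal diagonal.
bool is_masked(int m_base_idx, int n_base_idx, int m, int n,
               int q_len, int kv_len, int group_size) {
  const int q_packed_idx = m_base_idx + m;   // row in packed (q_len * group_size) space
  const int kv_idx = n_base_idx + n;         // absolute key/value position
  const int q_idx = q_packed_idx / group_size + (kv_len - q_len);
  const bool out_of_bound = q_packed_idx >= q_len * group_size || kv_idx >= kv_len;
  return out_of_bound || kv_idx > q_idx;
}

int main() {
  // Second m block with BLK_M = 64, first n block with BLK_N = 64.
  printf("%d\n", is_masked(/*m_base_idx=*/64, /*n_base_idx=*/0, /*m=*/0, /*n=*/63,
                           /*q_len=*/62, /*kv_len=*/127, /*group_size=*/8));
  return 0;
}
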
14 changes: 9 additions & 5 deletions src/kernels/attention/mha_kernel_sm80.cuh
@@ -338,13 +338,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerThr>;
using Mask = Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL>;
using Mask = Mask<kRowsPerThr, ALIBI, LOCAL>;

Softmax softmax(sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);
if constexpr (ALIBI) {
mask.init_alibi(
tScS_mn, m_block_idx, kv_head_idx, sm_scale, params.alibi_slopes_ptr);
mask.init_alibi(tScS_mn,
m_block_idx * kBlockM,
kv_head_idx,
sm_scale,
params.alibi_slopes_ptr);
}

CUTE_NO_UNROLL
@@ -376,10 +379,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
}

if (i < n_oob_mask) {
mask.apply(tSrS_mn, tScS_mn, m_block_idx, n_block_idx);
mask.apply(
tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
} else {
mask.apply</*OOB_MASK=*/false>(
tSrS_mn, tScS_mn, m_block_idx, n_block_idx);
tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
}
softmax.rescale(tSrS_mn, tOrO_mn);

100 changes: 72 additions & 28 deletions src/kernels/attention/mla_kernel_sm80.cuh
@@ -16,12 +16,7 @@

namespace llm {

template <typename Traits,
typename Params,
bool EVEN_K,
bool ALIBI,
bool SOFT_CAP,
bool LOCAL>
template <typename Traits, typename Params>
__global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
__grid_constant__ const Params params) {
using namespace cute;
@@ -89,9 +84,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// K_ROPE: (kv_len, ROPE_HEAD_DIM)
auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(batch_idx);

const int q_len = size<0>(Q) / group_size;
const int q_packed_len = size<0>(Q);
const int q_len = q_packed_len / group_size;
const int kv_len = size<0>(KV);
const int sliding_window = kv_len;

if (m_block_idx * kBlockM >= size<0>(Q)) {
// m out of bound, return
@@ -180,42 +175,84 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// g2s tiled copy for q
GmemTiledCopyQ gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);

// coordinate tensor for oob handling
// (BLK_M, BLK_K) -> (blk_m, blk_k)
Tensor cQ = make_identity_tensor(Shape<_BLK_M, _BLK_K>{});
Tensor tCcQ = gmem_thr_copy_Q.partition_S(cQ(_, _));
auto max_coord_Q = make_coord(q_packed_len - m_block_idx * kBlockM, kBlockK);

auto produce_q = [&](int step) {
// gQ/sQ: (BLK_M, BLK_K, STEPS)
auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step));
auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step));
cute::copy(gmem_tiled_copy_Q, tCgQ, tCsQ);
safe_copy</*EVEN_MN=*/false,
/*EVEN_K=*/true,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(
gmem_tiled_copy_Q, tCgQ, tCsQ, tCcQ, max_coord_Q);
};

// g2s tiled copy for q_rope
GmemTiledCopyQRope gmem_tiled_copy_Q_rope;
auto gmem_thr_copy_Q_rope = gmem_tiled_copy_Q_rope.get_slice(tidx);

// (BLK_M, ROPE_HEAD_DIM) -> (blk_m, rope_head_dim)
Tensor cQ_rope = make_identity_tensor(Shape<_BLK_M, _ROPE_HEAD_DIM>{});
Tensor tCcQ_rope = gmem_thr_copy_Q_rope.partition_S(cQ_rope);

auto produce_q_rope = [&]() {
auto tCgQ_rope = gmem_thr_copy_Q_rope.partition_S(gQ_rope);
auto tCsQ_rope = gmem_thr_copy_Q_rope.partition_D(sQ_rope);
cute::copy(gmem_tiled_copy_Q_rope, tCgQ_rope, tCsQ_rope);
auto max_coord =
make_coord(q_packed_len - m_block_idx * kBlockM, kRopeHeadDim);
safe_copy</*EVEN_MN=*/false,
/*EVEN_K=*/true,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(
gmem_tiled_copy_Q_rope, tCgQ_rope, tCsQ_rope, tCcQ_rope, max_coord);
};

// g2s tiled copy for kv
GmemTiledCopyKV gmem_tiled_copy_KV;
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);

// (BLK_N, BLK_K) -> (blk_n, blk_k)
Tensor cKV = make_identity_tensor(Shape<_BLK_N, _BLK_K>{});
Tensor tCcKV = gmem_thr_copy_KV.partition_S(cKV);

auto produce_kv = [&](int ni, int step, int stage) {
// gKV: (BLK_N, BLK_K, n, STEPS)
// sK: (BLK_N, BLK_K, STEPS, STAGES)
auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step));
auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, step, stage));
cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV);
auto max_coord = make_coord(kv_len - ni * kBlockN, kBlockK);
safe_copy</*EVEN_MN=*/false,
/*EVEN_K=*/true,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(
gmem_tiled_copy_KV, tCgKV, tCsKV, tCcKV, max_coord);
};

// g2s tiled copy for k_rope
GmemTiledCopyKRope gmem_tiled_copy_K_rope;
auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx);

// (BLK_N, ROPE_HEAD_DIM) -> (blk_n, rope_head_dim)
Tensor cK_rope = make_identity_tensor(Shape<_BLK_N, _ROPE_HEAD_DIM>{});
Tensor tKcK_rope = gmem_thr_copy_K_rope.partition_S(cK_rope);

auto produce_k_rope = [&](int ni, int stage) {
// gK_rope: (BLK_N, ROPE_HEAD_DIM, n)
// sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni));
Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
auto tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
auto max_coord = make_coord(kv_len - ni * kBlockN, kRopeHeadDim);
safe_copy</*EVEN_MN=*/false,
/*EVEN_K=*/true,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(
gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope, tKcK_rope, max_coord);
};

// GEMM-I: S = Q@K.T
@@ -382,9 +419,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx);

auto tCsO = gmem_thr_copy_O.partition_S(sO);
auto tCgO = gmem_thr_copy_O.partition_D(gO);
cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
// (BLK_M, BLK_K) -> (blk_m, blk_k)
auto cO = make_identity_tensor(Shape<_BLK_M, _BLK_K>{});
auto tCcO = gmem_thr_copy_Q.partition_S(cO);
auto max_coord_O =
make_coord(q_packed_len - m_block_idx * kBlockM, kBlockK);

CUTE_UNROLL
for (int step = 0; step < kSteps; ++step) {
auto tCsO = gmem_thr_copy_O.partition_S(sO(_, _, step));
auto tCgO = gmem_thr_copy_O.partition_D(gO(_, _, step));

safe_copy</*EVEN_MN=*/false,
/*EVEN_K=*/true,
/*ZFILL_MN=*/false,
/*ZFILL_K=*/false>(
gmem_tiled_copy_O, tCsO, tCgO, tCcO, max_coord_O);
}
};

// output accumulator: (MMA,MMA_M,MMA_K,STEPS)
@@ -441,13 +492,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerThr>;
using Mask = Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL>;
using Mask = Mask<kRowsPerThr, /*ALIBI=*/false, /*LOCAL=*/false>;

Softmax softmax(params.sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);
Mask mask(q_len, kv_len, group_size, /*sliding_window=*/kv_len);

int stage = 0;

CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
clear(tSrS);
@@ -470,7 +520,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
}

// apply mask + softmax
mask.apply(tSrS_mn, tScS_mn, m_block_idx, ni);
mask.apply(tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN);
softmax.rescale(tSrS_mn, tOrO_mn, reduce_rowmax);

// save tSrS from rmem to smem
@@ -520,21 +570,15 @@
epilogue(tOrO);
}

template <typename Traits,
typename Params,
bool EVEN_K = false,
bool ALIBI = false,
bool SOFT_CAP = false,
bool LOCAL = false>
template <typename Traits, typename Params>
void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) {
const auto batch_size = params.batch_size;
const auto max_q_packed_len = params.max_q_len * params.n_heads;

const auto smem_size = Traits::kSmemSize;
// print("smem_size: %d\n", smem_size);

auto mla_kernel =
mla_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
auto mla_kernel = mla_kernel_sm80<Traits, Params>;
C10_CUDA_CHECK(cudaFuncSetAttribute(
mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// TODO: support persistent kernels
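
The produce_* and epilogue changes above replace plain cute::copy with predicated safe_copy calls so the last q block and the last kv block can be partially filled. A simplified CUDA sketch of that idea, a bounds-checked global-to-shared tile copy that zero-fills rows past the end of the tensor; the real kernel uses CUTE tiled copies with per-element coordinate tensors (tCcQ/tCcKV) rather than this scalar loop, and all names here are illustrative:

template <int BLK_M, int BLK_K, typename T>
__device__ void copy_tile_zfill(const T* __restrict__ gmem,  // (rows, BLK_K), row-major
                                T* __restrict__ smem,         // (BLK_M, BLK_K) tile
                                int row_base,                 // e.g. m_block_idx * BLK_M
                                int rows) {                   // valid rows in gmem
  for (int idx = threadIdx.x; idx < BLK_M * BLK_K; idx += blockDim.x) {
    const int m = idx / BLK_K;
    const int k = idx % BLK_K;
    const int row = row_base + m;
    // Rows beyond the valid range are zero-filled (the ZFILL_MN behavior) so the
    // following MMA and softmax steps never consume uninitialized shared memory.
    smem[m * BLK_K + k] = (row < rows) ? gmem[row * BLK_K + k] : T(0);
  }
}
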
22 changes: 11 additions & 11 deletions src/kernels/attention/mla_kernel_sm80_test.cu
@@ -15,21 +15,21 @@ namespace llm {

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 128) { \
if (HEAD_DIM_V == 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 128; \
constexpr static int STAGES = 2; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 256) { \
} else if (HEAD_DIM_V == 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 32; \
constexpr static int BLK_K = 128; \
constexpr static int STAGES = 2; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 512) { \
} else if (HEAD_DIM_V == 512) { \
constexpr static int HEAD_DIM_NAME = 512; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 16; \
@@ -43,7 +43,7 @@ namespace llm {

#define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \
[&] { \
if (ROPE_HEAD_DIM_V <= 64) { \
if (ROPE_HEAD_DIM_V == 64) { \
constexpr static int ROPE_HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else { \
@@ -159,13 +159,13 @@ TEST_P(MLAKernelTest, MLA) {
INSTANTIATE_TEST_SUITE_P(
MLA,
MLAKernelTest,
::testing::Combine(::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1, 2, 4, 10), // batch_size
::testing::Values(64), // q_len
::testing::Values(64, 128, 1024), // kv_len
::testing::Values(1, 8, 128), // n_heads
::testing::Values(128, 256, 512), // head_dim
::testing::Values(64) // rope_head_dim
::testing::Combine(::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1, 2, 4, 10), // batch_size
::testing::Values(1, 62, 125), // q_len
::testing::Values(127, 287, 1000), // kv_len
::testing::Values(1, 8, 128), // n_heads
::testing::Values(128, 256, 512), // head_dim
::testing::Values(64) // rope_head_dim
));

} // namespace llm
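
The new test shapes deliberately avoid multiples of the tile sizes so the q and kv OOB paths are exercised. A quick check of how many out-of-bound rows the last tiles see, assuming the MLA path packs all query heads into the row dimension (q_packed_len = q_len * n_heads) and the BLK_M = BLK_N = 64 configuration used for head_dim 128; the shapes are taken from the parameter lists above:

#include <cstdio>

int main() {
  const int q_len = 62, n_heads = 8, kv_len = 127;
  const int BLK_M = 64, BLK_N = 64;
  const int q_packed_len = q_len * n_heads;                 // 496
  const int m_blocks = (q_packed_len + BLK_M - 1) / BLK_M;  // 8
  const int n_blocks = (kv_len + BLK_N - 1) / BLK_N;        // 2
  printf("oob q rows in last m block:  %d\n", m_blocks * BLK_M - q_packed_len);  // 16
  printf("oob kv rows in last n block: %d\n", n_blocks * BLK_N - kv_len);        // 1
  return 0;
}

The previous values (q_len = 64, kv_len = 64/128/1024) were all tile-aligned and never hit these paths.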
