kernel: fix mask bugs for MLA #408

Merged 4 commits on Feb 15, 2025

12 changes: 3 additions & 9 deletions src/kernels/attention/mask.h
@@ -5,16 +5,10 @@
namespace llm {
using namespace cute;

template <int BLK_M,
int BLK_N,
int ROWS_PER_MMA,
int MMA_M,
bool ALIBI,
bool LOCAL>
template <int BLK_M, int BLK_N, int ROWS_PER_THR, bool ALIBI, bool LOCAL>
struct Mask {
// Fragment type for alibi slopes: (2, MMA_M)
using FragmentT =
decltype(make_tensor<float>(Shape<Int<ROWS_PER_MMA>, Int<MMA_M>>{}));
// Fragment type for alibi slopes
using FragmentT = decltype(make_tensor<float>(Int<ROWS_PER_THR>{}));

int q_len_;
int kv_len_;
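Note on the mask.h change: the 2-D (ROWS_PER_MMA, MMA_M) alibi-slope fragment is collapsed into a flat per-thread row count, so Mask no longer needs to know the MMA decomposition. A minimal sketch of how a caller derives that count (it mirrors the kernel changes below; the kRowsPerMMA value of 2 per SM80 16x8x16 MMA tile is taken from the old "(2, MMA_M)" comment):

// Sketch only: deriving ROWS_PER_THR from the score-accumulator partition.
// tSrS has shape (MMA, MMA_M, MMA_N); each thread owns kRowsPerMMA rows per MMA_M tile.
constexpr int kRowsPerMMA = 2;                            // rows per thread per MMA tile
constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);  // total rows this thread masks
Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL> mask(q_len, kv_len, group_size, sliding_window);
// Inside Mask, the alibi slopes now live in a flat fragment:
//   FragmentT = decltype(make_tensor<float>(Int<kRowsPerThr>{}));  // one slope per row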
6 changes: 3 additions & 3 deletions src/kernels/attention/mha_kernel_sm80.cuh
@@ -336,9 +336,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
thr_mma.partition_C(make_identity_tensor(Shape<_BLK_M, _BLK_N>{}));
auto tScS_mn = make_tensor(tScS.data(), Layout::to_mn(tScS.layout()));

constexpr int kMMA_M = size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
using Mask = Mask<kBlockM, kBlockM, kRowsPerMMA, kMMA_M, ALIBI, LOCAL>;
constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerThr>;
using Mask = Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL>;

Softmax softmax(sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);
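The functional fix here is the second template argument: Mask<kBlockM, kBlockM, ...> became Mask<kBlockM, kBlockN, ...>, so the mask finally sees the real kv tile width (the MLA test dispatch below picks BLK_M = 64 with BLK_N = 32 for head_dim 256, so the two do differ in practice). Assuming the mask derives kv positions from the tile index and tile width (that code is not shown in this diff), a standalone illustration of what the old argument did:

// Hypothetical illustration, not from the patch: kv index of column local_n in kv tile n_block.
#include <cstdio>

int main() {
  constexpr int kBlockM = 64, kBlockN = 32;  // example tile sizes; real values come from Traits
  const int n_block = 2, local_n = 5;        // third kv tile, sixth column within it
  const int kv_idx_right = n_block * kBlockN + local_n;  // 69
  const int kv_idx_wrong = n_block * kBlockM + local_n;  // 133: bounds checked against the wrong column
  std::printf("right=%d wrong=%d\n", kv_idx_right, kv_idx_wrong);
  return 0;
}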
45 changes: 19 additions & 26 deletions src/kernels/attention/mla_kernel_sm80.cuh
@@ -83,9 +83,10 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

// ProblemShape
// Q/O: (q_packed_len, HEAD_DIM)
// KV: (kv_len, HEAD_DIM)
// Q/K_ROPE: (q_packed_len, ROPE_HEAD_DIM)
// Q_ROPE: (q_packed_len, ROPE_HEAD_DIM)
auto [Q, Q_ROPE, O] = tile.template get_qo_tile<DType>(batch_idx);
// KV: (kv_len, HEAD_DIM)
// K_ROPE: (kv_len, ROPE_HEAD_DIM)
auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(batch_idx);

const int q_len = size<0>(Q) / group_size;
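The comment fix splits the rope shapes: Q_ROPE is indexed by the packed query position, K_ROPE by the kv position. For orientation, the row bookkeeping those comments describe (illustrative names; only the q_len line is part of the patch):

// Q/O and Q_ROPE pack heads into the row dimension; KV and K_ROPE carry one row per kv position.
const int q_packed_len = size<0>(Q);                 // q_len * group_size
const int q_len        = q_packed_len / group_size;  // as computed above
const int kv_len       = size<0>(KV);                // accessor assumed; matches "KV: (kv_len, HEAD_DIM)"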
@@ -148,11 +149,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
Tensor sRowsum =
make_tensor(make_smem_ptr(row_sync_smem), SmemLayoutRowsum{});

// thread layout: (32, 8), each thread process 2 rows
// (store_idx, load_idx) = (0, 64), (1, 65), ...
// reduce rowmax/rowsum accross 2 warps via shared memory
// thread layout: (32, (4, 2)), each thread process 2 rows
// (store_idx, load_idx) = (0, 64) or (1, 65), ...
const int row_store_idx = tidx / 4 * 2;
const int row_load_idx = row_store_idx ^ kBlockM;
// reduce rowmax accross 2 warps
auto reduce_rowmax = [&](auto& row_max) {
CUTE_UNROLL
for (int i = 0; i < size(row_max); ++i) {
@@ -164,8 +165,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
row_max(i) = max(row_max(i), sRowmax(row_load_idx + i));
}
};
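The rewritten comments describe the cross-warp reduction: each thread publishes its two rows of the running rowmax into shared memory at row_store_idx and merges in the partner warp's values from row_store_idx ^ kBlockM. A hedged sketch of the full exchange (the store loop and the barrier placement are assumptions; only the load side is visible in this hunk):

// Sketch: merge per-row maxima across the two warps that share a BLK_M tile.
const int row_store_idx = tidx / 4 * 2;             // first of this thread's 2 rows
const int row_load_idx  = row_store_idx ^ kBlockM;  // same rows, other warp's half of sRowmax
auto reduce_rowmax = [&](auto& row_max) {
  CUTE_UNROLL
  for (int i = 0; i < size(row_max); ++i) {
    sRowmax(row_store_idx + i) = row_max(i);        // publish my partial rowmax
  }
  __syncthreads();                                  // assumed barrier between store and load
  CUTE_UNROLL
  for (int i = 0; i < size(row_max); ++i) {
    row_max(i) = max(row_max(i), sRowmax(row_load_idx + i));  // merge the partner's value
  }
};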

// reduce rowsum accross 2 warps
auto reduce_rowsum = [&](auto& row_sum) {
CUTE_UNROLL
for (int i = 0; i < size(row_sum); ++i) {
@@ -218,16 +217,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
};
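produce_kv and produce_k_rope above are the asynchronous producers for the next kv tile; the consumer side shows up in the mainloop hunk at the end of this file (cp_async_fence). A compact sketch of the intended ordering, assuming the gmem tiled copies are cp.async based:

// Hedged sketch of the producer/consumer ordering (stage handling simplified):
produce_kv(ni, /*stage=*/s);  // issue cp.async copies gmem -> smem for stage s
cp_async_fence();             // commit this group of async copies
// ... later, before the GEMM that reads stage s:
cp_async_wait<0>();           // wait until the committed copies have landed
__syncthreads();              // make the freshly written smem visible to all warps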

// GEMM-I: S = Q@K.T
TiledMma_QK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
// GEMM-I: S = Q@K.T
// sQ/sK: (BLK_M, BLK_K, STAGES)
auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{}));
auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}));
auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope);

// s2r tiled copy for q
// s2r tiled copy for q/q_rope
SmemTiledCopyQ smem_tiled_copy_Q;
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
// (CPY, CPY_M, CPY_K, STAGES)
@@ -240,7 +239,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// (CPY, CPY_M, CPY_K)
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);

// s2r tiled copy for k
// s2r tiled copy for k/k_rope
SmemTiledCopyK smem_tiled_copy_K;
auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
@@ -297,11 +296,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// GEMM-II: O = softmax(S)@V
TiledMma_PV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
// sS: (BLK_M, BLK_N)
// (MMA, MMA_M, MMA_K)
// sP: (BLK_M, BLK_N)
auto tOrP = thr_mma_pv.partition_fragment_A(sP);
// sVt: (BLK_K, BLK_N, STAGES)
// (MMA, MMA_N, MMA_K)
auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}));

// s2r tiled copy for p
@@ -321,12 +318,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt);

// O = P*V = softmax(S)*V
// tOrS: (MMA,MMA_M,MMA_K)
// tOrO: (MMA,MMA_M,MMA_K,STAGES)
auto compute_pv = [&](auto& tOrO, int s) {
// (MMA,MMA_M,MMA_N, STAGES)
auto tOrO_s = tOrO(_, _, _, s);

// (CPY, CPY_N, CPY_K, STAGES)
auto tCsVt_s = tCsVt(_, _, _, s);
cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));
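compute_pv preloads the first k-slice of P and Vt from shared memory before the MMA loop; the usual CUTE idiom then overlaps the copy of slice k+1 with the gemm on slice k. A hedged sketch of that pipeline (the loop body below the shown copies is not part of this hunk):

// Assumed continuation of compute_pv: software-pipelined s2r copy + gemm.
CUTE_UNROLL
for (int k = 0; k < size<2>(tOrP); ++k) {
  if (k != size<2>(tOrP) - 1) {  // prefetch the next k-slice while computing this one
    cute::copy(smem_tiled_copy_P,  tCsP(_, _, k + 1),    tCrP(_, _, k + 1));
    cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, k + 1), tCrVt(_, _, k + 1));
  }
  cute::gemm(tiled_mma_pv, tOrP(_, _, k), tOrVt(_, _, k), tOrO_s);
}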
@@ -346,7 +340,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// r2s tiled copy for S/P
SmemTiledCopyS smem_tiled_copy_S;
auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(tidx);

auto store_s_to_smem = [&](const auto& tSrS) {
// cast Accumulator to Element type
auto tSrS_ = make_tensor_like<DType>(tSrS);
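store_s_to_smem downcasts the fp32 score accumulator to the kernel's element type before the register-to-smem copy so that GEMM-II can read it back as operand A. A sketch of the rest of the helper, with only the first line taken from this hunk (the cast loop and the sP destination are assumptions):

auto store_s_to_smem = [&](const auto& tSrS) {
  // cast the fp32 accumulator to the kernel element type (DType)
  auto tSrS_ = make_tensor_like<DType>(tSrS);
  CUTE_UNROLL
  for (int i = 0; i < size(tSrS); ++i) {
    tSrS_(i) = static_cast<DType>(tSrS(i));      // real code may use a vectorized cast helper
  }
  auto tCrS = smem_thr_copy_S.retile_S(tSrS_);   // retile registers for the r2s copy atom
  auto tCsS = smem_thr_copy_S.partition_D(sP);   // destination: the P tile in shared memory
  cute::copy(smem_tiled_copy_S, tCrS, tCsS);
};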
Expand Down Expand Up @@ -376,7 +369,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto tCsO = smem_thr_copy_O.partition_D(sO_s);
cute::copy(smem_tiled_copy_O, tCrO, tCsO);
}
// wait for smem copy done before gmem copy

__syncthreads();

// 2. copy output from smem to gmem
@@ -388,7 +381,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
};

// output accumulator: (MMA, MMA_M, MMA_K, STAGES)
// output accumulator: (MMA,MMA_M,MMA_K,STAGES)
auto tOrO =
partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
@@ -416,16 +409,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// ############### Mainloop ###############
// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrS = partition_fragment_C(tiled_mma_qk, Shape<_BLK_M, _BLK_N>{});
auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_mn(tSrS.layout()));

// identity tensor for score accumulator
auto tScS =
thr_mma_qk.partition_C(make_identity_tensor(Shape<_BLK_M, _BLK_N>{}));
auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_mn(tSrS.layout()));
auto tScS_mn = make_tensor(tScS.data(), Layout::to_mn(tScS.layout()));

constexpr int kMMA_M = size<1>(tOrO);
using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
using Mask = Mask<kBlockM, kBlockM, kRowsPerMMA, kMMA_M, ALIBI, LOCAL>;
constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerThr>;
using Mask = Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL>;

Softmax softmax(params.sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);
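Same two fixes as in the MHA kernel, plus one more: the per-thread row count is now derived from the score accumulator tSrS (partitioned by tiled_mma_qk over BLK_M x BLK_N) instead of the output accumulator tOrO (partitioned by tiled_mma_pv over BLK_M x BLK_K), whose M decomposition need not match. With that, OnlineSoftmax and Mask agree on how many rows each thread owns:

// Both per-row state holders are sized from the same partition now:
constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);  // rows of S owned by this thread
using Softmax = OnlineSoftmax<kRowsPerThr>;               // running rowmax/rowsum, one slot per row
using Mask    = Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL>;
// kRowsPerThr is also what size<0>(tSrS_mn) evaluates to, so the mask and the softmax
// walk the same rows of the mn view (the apply signature itself is not shown in this diff).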
@@ -468,6 +459,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
__syncthreads();

produce_kv(next_ni, s);
cp_async_fence();
}
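The added __syncthreads() is the other half of the fix in this file: compute_pv(tOrO, s) reads sP and stage s of the value smem, while produce_kv(next_ni, s) starts overwriting that same stage with the next tile. The hazard description is an inference (the smem aliasing is not spelled out in the hunk), but the barrier placement is exactly what the patch adds:

CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
  compute_pv(tOrO, s);     // consumes sP and sVt(_, _, s) from shared memory
  __syncthreads();         // every warp finishes reading stage s before it is refilled
  produce_kv(next_ni, s);  // cp.async copies that overwrite the KV smem for stage s
  cp_async_fence();        // commit the async group for this stage
}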
50 changes: 19 additions & 31 deletions src/kernels/attention/mla_kernel_sm80_test.cu
@@ -15,22 +15,16 @@ namespace llm {

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 128) { \
if (HEAD_DIM_V <= 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 128; \
constexpr static int BLK_K = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_N = 32; \
constexpr static int BLK_K = 128; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 512) { \
@@ -136,44 +130,38 @@ TEST_P(MLAKernelTest, MLA) {
n_heads,
head_dim,
rope_head_dim] = GetParam();
// const auto head_dim = kv_lora_rank + rope_head_dim;
const auto options = torch::dtype(dtype).device(torch::kCUDA);

// q: [batch, len, n_heads, head_dim]
// kv: [batch, len, head_dim]
const auto q =
torch::randn({batch_size, q_len, n_heads, head_dim}, options) * 0.001;
const auto kv = torch::randn({batch_size, kv_len, head_dim}, options) * 0.001;
// q: [batch, q_len, n_heads, head_dim]
// kv: [batch, kv_len, head_dim]
const auto q = torch::randn({batch_size, q_len, n_heads, head_dim}, options);
const auto kv = torch::randn({batch_size, kv_len, head_dim}, options);

// q_rope: [batch, len, n_heads, rope_head_dim]
// kv_rope: [batch, len, rope_head_dim]
// q_rope: [batch, q_len, n_heads, rope_head_dim]
// kv_rope: [batch, kv_len, rope_head_dim]
const auto q_rope =
torch::randn({batch_size, q_len, n_heads, rope_head_dim}, options) * 0.01;
torch::randn({batch_size, q_len, n_heads, rope_head_dim}, options);
const auto k_rope =
torch::randn({batch_size, kv_len, rope_head_dim}, options) * 0.01;
torch::randn({batch_size, kv_len, rope_head_dim}, options);

const float sm_scale = 1.0 / sqrt(head_dim + rope_head_dim);

auto ref_out = mla_batch_ref(q, kv, q_rope, k_rope, sm_scale);
auto out = mla_sm80(q, kv, q_rope, k_rope, sm_scale);
// std::cerr << "max diff: " << (ref_out - out).abs().max() << std::endl;
if (head_dim >= 512) {
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/3e-2, /*atol=*/3e-2));
} else {
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
}
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
}
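The test now feeds unscaled inputs and uses a single tolerance for every head_dim. For readers without the repo checked out, a hedged sketch of what a reference like mla_batch_ref plausibly computes; the einsum formulation and the reuse of kv as the value matrix are assumptions based on how MLA absorbs the value projection, and the real helper is defined elsewhere in this test file:

// q: [b, q_len, h, d], kv: [b, kv_len, d], q_rope: [b, q_len, h, r], k_rope: [b, kv_len, r]
auto scores = torch::einsum("bqhd,bkd->bhqk", {q, kv}) +
              torch::einsum("bqhr,bkr->bhqk", {q_rope, k_rope});
auto probs  = torch::softmax(scores.to(torch::kFloat) * sm_scale, /*dim=*/-1);
auto ref    = torch::einsum("bhqk,bkd->bqhd", {probs, kv.to(torch::kFloat)})
                  .to(q.scalar_type());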

INSTANTIATE_TEST_SUITE_P(
MLA,
MLAKernelTest,
::testing::Combine(::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1, 2, 4), // batch_size
::testing::Values(64), // q_len
::testing::Values(64, 128), // kv_len
::testing::Values(1, 8, 24, 128), // n_heads
::testing::Values(64, 128, 256, 512), // head_dim
::testing::Values(64) // rope_head_dim
::testing::Combine(::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1, 2, 4, 10), // batch_size
::testing::Values(64), // q_len
::testing::Values(64, 128), // kv_len
::testing::Values(1, 8, 128), // n_heads
::testing::Values(128, 256, 512), // head_dim
::testing::Values(64) // rope_head_dim
));

} // namespace llm