kernel: fix kv oob issue with stages > 1 and added more unittests for paged MLA (#416)
guocuimi authored Feb 28, 2025
1 parent 4b8114b commit 1cc659f
Showing 3 changed files with 31 additions and 17 deletions.
21 changes: 15 additions & 6 deletions src/kernels/attention/mla_kernel_sm80.cuh
@@ -545,13 +545,22 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
         cp_async_fence();
       }
     }
-    stage == 0 ? produce_k_rope(ni, stage) : produce_k_rope_no_oob(ni, stage);
-    cp_async_fence();
-    CUTE_UNROLL
-    for (int step = 0; step < kSteps; ++step) {
-      stage == 0 ? produce_kv(ni, step, stage)
-                 : produce_kv_no_oob(ni, step, stage);
-      cp_async_fence();
-    }
+    // handle oob kv
+    if (ni >= n_block_min) {
+      stage == 0 ? produce_k_rope(ni, stage) : produce_k_rope_no_oob(ni, stage);
+      cp_async_fence();
+      CUTE_UNROLL
+      for (int step = 0; step < kSteps; ++step) {
+        stage == 0 ? produce_kv(ni, step, stage)
+                   : produce_kv_no_oob(ni, step, stage);
+        cp_async_fence();
+      }
+    } else {
+      cp_async_fence();
+      CUTE_UNROLL
+      for (int step = 0; step < kSteps; ++step) {
+        cp_async_fence();
+      }
+    }
   }
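Why the fix works: with kStages > 1 the consumer synchronizes via cp.async commit groups, which count fences, not loop iterations. Each pipeline iteration is expected to commit 1 + kSteps groups (one fence after the k_rope copy, one after each kv step), and the matching cp_async_wait assumes that cadence. Before this change, an out-of-range block (ni < n_block_min) skipped both the copies and the fences, so the group count fell behind and a later wait could unblock while a valid block's copies were still in flight, letting the consumer read half-filled shared memory. The else branch therefore commits the same number of empty groups. Below is a minimal host-side sketch of that bookkeeping; the kStages/kSteps values and the wait threshold are illustrative assumptions, not the kernel's actual constants.

// Host-side model of cp.async commit-group accounting -- a sketch under
// the assumptions above, not the kernel's real synchronization code.
#include <deque>
#include <iostream>
#include <string>

constexpr int kStages = 2;                  // pipeline depth (illustrative)
constexpr int kSteps = 2;                   // kv copy steps (illustrative)
constexpr int kGroupsPerIter = 1 + kSteps;  // one k_rope fence + kSteps kv fences

struct CpAsyncModel {
  std::deque<std::string> in_flight;  // committed groups, data not yet landed
  std::deque<std::string> landed;     // groups whose data is safe to read

  // cp_async_fence(): commit everything staged so far as one group
  void fence(const std::string& group) { in_flight.push_back(group); }

  // cp_async_wait<n>(): block until at most n groups remain in flight
  void wait(int n) {
    while (static_cast<int>(in_flight.size()) > n) {
      landed.push_back(in_flight.front());
      in_flight.pop_front();
    }
  }
};

int main() {
  CpAsyncModel pipe;

  // Iteration 0 prefetches a valid block; iteration 1 is out of range
  // (ni < n_block_min) but, as in the fix, still commits empty groups.
  pipe.fence("k_rope[block0]");
  for (int s = 0; s < kSteps; ++s) pipe.fence("kv[block0]");
  for (int g = 0; g < kGroupsPerIter; ++g) pipe.fence("<empty>");

  // Consumer: allow only the newest iteration's groups to stay in flight,
  // which forces everything block0 loaded to have landed before reading.
  pipe.wait((kStages - 1) * kGroupsPerIter);
  for (const auto& g : pipe.landed) std::cout << "landed: " << g << "\n";

  // Had the OOB iteration skipped its fences, the same wait would have
  // seen only three groups in flight and returned at once, with block0's
  // data still in transit -- the out-of-bounds bug this commit fixes.
  return 0;
}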

20 changes: 10 additions & 10 deletions src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
@@ -210,8 +210,7 @@ TEST_P(MLAKernelPagedKVTest, PageKV) {
     }
   }
 
-  // construct non-contiguous query, key and value
-  // generate query, key and value
+  // generate q, kv, q_rope, k_rope
   torch::Tensor q = torch::rand({n_q_tokens, n_heads, head_dim}, options);
   const auto n_slots = total_blocks * block_size;
   torch::Tensor kv_cache = torch::rand({n_slots, head_dim}, options);
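As the context lines above show, the paged cache is a flat pool of n_slots = total_blocks * block_size rows. A hedged sketch of the usual block-table translation from a token's logical position to its cache slot follows (the vLLM-style convention is assumed; the test's actual block-table construction is outside this excerpt, and block_table here is hypothetical).

// Illustrative paged-KV slot lookup: a token's logical position maps to
// a row of the [n_slots, head_dim] cache through a per-sequence block table.
#include <cstdio>
#include <vector>

int main() {
  const int block_size = 8;
  // Hypothetical block table: logical block i of the sequence lives in
  // physical block block_table[i] of the cache pool.
  const std::vector<int> block_table = {5, 2, 9};

  const int token_pos = 13;  // logical position within the sequence
  const int slot = block_table[token_pos / block_size] * block_size +
                   token_pos % block_size;
  std::printf("token %d -> slot %d\n", token_pos, slot);  // 2*8 + 5 = 21
  return 0;
}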
@@ -270,14 +269,15 @@ TEST_P(MLAKernelPagedKVTest, PageKV) {
 INSTANTIATE_TEST_SUITE_P(
     MLA,
     MLAKernelPagedKVTest,
-    ::testing::Combine(::testing::Values(torch::kHalf),  // q_dtype
-                       ::testing::Values(1, 2, 4),       // batch_size
-                       ::testing::Values(1, 8, 64),      // block_size
-                       ::testing::Values(1, 125),        // max_q_len
-                       ::testing::Values(127, 1000),     // max_kv_len
-                       ::testing::Values(1, 8, 128),     // n_heads
-                       ::testing::Values(512),           // head_dim
-                       ::testing::Values(64)             // rope_head_dim
+    ::testing::Combine(::testing::Values(torch::kHalf,
+                                         torch::kBFloat16),  // q_dtype
+                       ::testing::Values(1, 2, 4),           // batch_size
+                       ::testing::Values(1, 8, 64),          // block_size
+                       ::testing::Values(1, 125),            // max_q_len
+                       ::testing::Values(127, 1000),         // max_kv_len
+                       ::testing::Values(1, 8, 128),         // n_heads
+                       ::testing::Values(128, 256, 512),     // head_dim
+                       ::testing::Values(64)                 // rope_head_dim
         ));
 
 } // namespace llm
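Since ::testing::Combine instantiates the full Cartesian product of its value lists, widening q_dtype and head_dim grows the suite multiplicatively; the factor counts below are read straight off the Values(...) lists above.

// Back-of-envelope count of paged-KV MLA test cases after this commit.
#include <cstdio>

int main() {
  const int n_cases = 2    // q_dtype: kHalf, kBFloat16
                    * 3    // batch_size: 1, 2, 4
                    * 3    // block_size: 1, 8, 64
                    * 2    // max_q_len: 1, 125
                    * 2    // max_kv_len: 127, 1000
                    * 3    // n_heads: 1, 8, 128
                    * 3    // head_dim: 128, 256, 512
                    * 1;   // rope_head_dim: 64
  std::printf("%d combinations (up from 108)\n", n_cases);  // prints 648
  return 0;
}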
7 changes: 6 additions & 1 deletion src/kernels/attention/mla_kernel_sm80_test.cu
@@ -149,6 +149,11 @@ TEST_P(MLAKernelTest, MLA) {
               n_heads,
               head_dim,
               rope_head_dim] = GetParam();
+  // skip invalid test cases
+  if (kv_len < q_len) {
+    return;
+  }
+
   const auto options = torch::dtype(dtype).device(torch::kCUDA);
 
   // q: [batch, q_len, n_heads, head_dim]
@@ -182,7 +187,7 @@ INSTANTIATE_TEST_SUITE_P(
                                          torch::kBFloat16),  // q_dtype
                        ::testing::Values(1, 2, 4, 10),       // batch_size
                        ::testing::Values(1, 62, 125),        // q_len
-                       ::testing::Values(127, 287, 1000),    // kv_len
+                       ::testing::Values(1, 30, 287, 1000),  // kv_len
                        ::testing::Values(1, 8, 128),         // n_heads
                        ::testing::Values(128, 256, 512),     // head_dim
                        ::testing::Values(64)                 // rope_head_dim
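The kv_len < q_len skip added in this file follows from right-aligned causal masking (assumed here as the common convention for such kernels; the kernel's exact mask may differ): query position i may attend to kv positions 0 through kv_len - q_len + i, so when kv_len < q_len the earliest query rows get an empty range and the shape is not a meaningful test case. A small illustration:

// Illustrative only: the attendable kv range per query position under a
// right-aligned causal mask, for a shape the test now skips.
#include <iostream>

int main() {
  const int q_len = 5, kv_len = 3;  // invalid: kv_len < q_len
  for (int i = 0; i < q_len; ++i) {
    const int hi = kv_len - q_len + i;  // last attendable kv index
    std::cout << "q[" << i << "] -> kv[0.." << hi << "]"
              << (hi < 0 ? "  (empty range)" : "") << "\n";
  }
  return 0;
}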
