From 1cc659f9623d48e38a4feeb4a74401b9ab11f239 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Fri, 28 Feb 2025 10:32:50 -1000
Subject: [PATCH] kernel: fix kv oob issue with stages > 1 and added more
 unittests for paged MLA (#416)

---
 src/kernels/attention/mla_kernel_sm80.cuh     | 21 +++++++++++++++------
 .../attention/mla_kernel_sm80_pagedkv_test.cu | 20 ++++++++++----------
 src/kernels/attention/mla_kernel_sm80_test.cu |  7 ++++++-
 3 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh
index bcde5e43..5d8efb95 100644
--- a/src/kernels/attention/mla_kernel_sm80.cuh
+++ b/src/kernels/attention/mla_kernel_sm80.cuh
@@ -545,13 +545,22 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
         cp_async_fence();
       }
     }
-    stage == 0 ? produce_k_rope(ni, stage) : produce_k_rope_no_oob(ni, stage);
-    cp_async_fence();
-    CUTE_UNROLL
-    for (int step = 0; step < kSteps; ++step) {
-      stage == 0 ? produce_kv(ni, step, stage)
-                 : produce_kv_no_oob(ni, step, stage);
+    // handle oob kv
+    if (ni >= n_block_min) {
+      stage == 0 ? produce_k_rope(ni, stage) : produce_k_rope_no_oob(ni, stage);
       cp_async_fence();
+      CUTE_UNROLL
+      for (int step = 0; step < kSteps; ++step) {
+        stage == 0 ? produce_kv(ni, step, stage)
+                   : produce_kv_no_oob(ni, step, stage);
+        cp_async_fence();
+      }
+    } else {
+      cp_async_fence();
+      CUTE_UNROLL
+      for (int step = 0; step < kSteps; ++step) {
+        cp_async_fence();
+      }
     }
   }
diff --git a/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
index caff8892..39294582 100644
--- a/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
+++ b/src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
@@ -210,8 +210,7 @@ TEST_P(MLAKernelPagedKVTest, PageKV) {
     }
   }
 
-  // construct non-contiguous query, key and value
-  // generate query, key and value
+  // generate q, kv, q_rope, k_rope
   torch::Tensor q = torch::rand({n_q_tokens, n_heads, head_dim}, options);
   const auto n_slots = total_blocks * block_size;
   torch::Tensor kv_cache = torch::rand({n_slots, head_dim}, options);
@@ -270,14 +269,15 @@ TEST_P(MLAKernelPagedKVTest, PageKV) {
 INSTANTIATE_TEST_SUITE_P(
     MLA,
     MLAKernelPagedKVTest,
-    ::testing::Combine(::testing::Values(torch::kHalf),  // q_dtype
-                       ::testing::Values(1, 2, 4),       // batch_size
-                       ::testing::Values(1, 8, 64),      // block_size
-                       ::testing::Values(1, 125),        // max_q_len
-                       ::testing::Values(127, 1000),     // max_kv_len
-                       ::testing::Values(1, 8, 128),     // n_heads
-                       ::testing::Values(512),           // head_dim
-                       ::testing::Values(64)             // rope_head_dim
+    ::testing::Combine(::testing::Values(torch::kHalf,
+                                         torch::kBFloat16),  // q_dtype
+                       ::testing::Values(1, 2, 4),            // batch_size
+                       ::testing::Values(1, 8, 64),           // block_size
+                       ::testing::Values(1, 125),             // max_q_len
+                       ::testing::Values(127, 1000),          // max_kv_len
+                       ::testing::Values(1, 8, 128),          // n_heads
+                       ::testing::Values(128, 256, 512),      // head_dim
+                       ::testing::Values(64)                  // rope_head_dim
                        ));
 
 }  // namespace llm
\ No newline at end of file
diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu
index 56b7396f..2b6c4098 100644
--- a/src/kernels/attention/mla_kernel_sm80_test.cu
+++ b/src/kernels/attention/mla_kernel_sm80_test.cu
@@ -149,6 +149,11 @@ TEST_P(MLAKernelTest, MLA) {
               n_heads,
               head_dim,
               rope_head_dim] = GetParam();
 
+  // skip invalid test cases
+  if (kv_len < q_len) {
+    return;
+  }
+
   const auto options = torch::dtype(dtype).device(torch::kCUDA);
   // q: [batch, q_len, n_heads, head_dim]
@@ -182,7 +187,7 @@ INSTANTIATE_TEST_SUITE_P(
                                          torch::kBFloat16),  // q_dtype
                        ::testing::Values(1, 2, 4, 10),       // batch_size
                        ::testing::Values(1, 62, 125),        // q_len
-                       ::testing::Values(127, 287, 1000),    // kv_len
+                       ::testing::Values(1, 30, 287, 1000),  // kv_len
                        ::testing::Values(1, 8, 128),         // n_heads
                        ::testing::Values(128, 256, 512),     // head_dim
                        ::testing::Values(64)                 // rope_head_dim
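
Reviewer note (not part of the patch): with kStages > 1 and a short KV sequence, the prologue can prefetch a block index below n_block_min, which is the out-of-bounds read the commit subject refers to. The fix skips the loads for such blocks but, in the new else branch, still issues the same number of cp_async_fence() calls, presumably so the consumer's cp.async wait bookkeeping sees a constant number of commit groups per stage. The sketch below re-creates that invariant using the public CUDA <cuda_pipeline.h> primitives (__pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior) rather than cute's cp_async_fence / cp_async_wait; the names pipelined_copy, load_block, kStages, kSteps, and kBlock are illustrative only and do not come from the kernel.

#include <cstdio>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>

constexpr int kStages = 2;    // prefetch depth, analogous to kStages in the kernel
constexpr int kSteps = 2;     // commit groups per block (k_rope + kv steps in the kernel)
constexpr int kBlock = 128;   // elements copied per step; launch with kBlock threads

__global__ void pipelined_copy(const float* src, float* dst, int n_blocks) {
  __shared__ float smem[kStages][kSteps][kBlock];

  // Load one block into a pipeline stage. The copy is skipped when the block
  // index is out of range, but a commit group is issued unconditionally so the
  // cost per block is always exactly kSteps groups.
  auto load_block = [&](int bi, int stage) {
    for (int step = 0; step < kSteps; ++step) {
      if (bi >= 0) {
        __pipeline_memcpy_async(&smem[stage][step][threadIdx.x],
                                &src[(bi * kSteps + step) * kBlock + threadIdx.x],
                                sizeof(float));
      }
      __pipeline_commit();  // empty group when bi < 0, mirroring the else branch
    }
  };

  // Prologue: prefetch up to kStages blocks, newest index first. When
  // n_blocks < kStages some of these indices are out of range.
  for (int stage = 0; stage < kStages; ++stage) {
    load_block(n_blocks - 1 - stage, stage);
  }

  for (int bi = n_blocks - 1; bi >= 0; --bi) {
    // Wait until only the (kStages - 1) most recent blocks are still in flight,
    // i.e. block bi has landed. This count is only correct because out-of-range
    // blocks committed their (empty) groups too.
    __pipeline_wait_prior((kStages - 1) * kSteps);
    __syncthreads();

    const int stage = (n_blocks - 1 - bi) % kStages;
    for (int step = 0; step < kSteps; ++step) {
      dst[(bi * kSteps + step) * kBlock + threadIdx.x] = smem[stage][step][threadIdx.x];
    }

    __syncthreads();
    load_block(bi - kStages, stage);  // refill this stage; may be out of range
  }
}

int main() {
  const int n_blocks = 1;  // fewer blocks than kStages: the stages > 1 corner case
  const int n = n_blocks * kSteps * kBlock;
  float* src = nullptr;
  float* dst = nullptr;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (int i = 0; i < n; ++i) src[i] = float(i);

  pipelined_copy<<<1, kBlock>>>(src, dst, n_blocks);
  cudaDeviceSynchronize();

  printf("dst[5] = %.1f (expected 5.0)\n", dst[5]);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

Running with n_blocks < kStages, as in main above, exercises the same short-KV corner case the added test values (e.g. kv_len = 1 and 30) are meant to cover; dropping the unconditional __pipeline_commit() in load_block makes __pipeline_wait_prior under-wait, which is the analogue of the synchronization drift the else branch avoids.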