kernel: added paged kv support for MLA kernel (#415)
guocuimi authored Feb 28, 2025
1 parent b8cba27 commit 4b8114b
Showing 7 changed files with 503 additions and 8 deletions.
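For context on the change itself: in a paged KV cache, each sequence's keys/values are not stored contiguously; the cache is a pool of fixed-size blocks, and a per-sequence block table maps logical token positions to physical cache slots. The sketch below illustrates that mapping the way the new test indexes it (block-table entries hold the first slot id of each block). logical_to_slot and its arguments are illustrative names, not part of this diff, and this is a minimal host-side sketch rather than the kernel's actual addressing.

#include <cstdint>
#include <vector>

// Map a logical KV position within one sequence to a physical slot in the
// pooled cache. Each block_table entry holds the first slot id of a block,
// so the mapping is one lookup plus an in-block offset.
inline int32_t logical_to_slot(const std::vector<int32_t>& block_table,
                               int32_t block_size,
                               int32_t pos) {
  const int32_t block_idx = pos / block_size;     // which block of this sequence
  const int32_t block_offset = pos % block_size;  // position inside that block
  return block_table[block_idx] + block_offset;   // physical slot in the cache
}

// Example: block_size = 4, block_table = {32, 8} (blocks need not be adjacent):
// positions 0..3 map to slots 32..35, positions 4..6 map to slots 8..10.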
1 change: 1 addition & 0 deletions src/kernels/attention/CMakeLists.txt
@@ -77,6 +77,7 @@ cc_test(
SRCS
mla_traits_test.cpp
mla_kernel_sm80_test.cu
mla_kernel_sm80_pagedkv_test.cu
DEPS
:attention.template
absl::random_random
16 changes: 15 additions & 1 deletion src/kernels/attention/mla_kernel_sm80.cuh
@@ -120,6 +120,20 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// K_ROPE: (kv_len, ROPE_HEAD_DIM)
auto [KV, K_ROPE] = tile.template get_kv_tile<DType>(batch_idx);

#if 0
if (thread0()) {
print("Q: "); print(Q); print("\n");
print("Q_ROPE: "); print(Q_ROPE); print("\n");
print("O: "); print(O); print("\n");
print("KV: "); print(KV); print("\n");
print("K_ROPE: "); print(K_ROPE); print("\n");
print("m_block_idx: %d, batch_idx: %d, group_size: %d\n",
m_block_idx,
batch_idx,
group_size);
}
#endif

const int q_packed_len = size<0>(Q);
const int q_len = q_packed_len / group_size;
const int kv_len = size<0>(KV);
@@ -642,7 +656,7 @@ void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) {
const auto max_q_packed_len = params.max_q_len * params.n_heads;

const auto smem_size = sizeof(MLASharedStorage<Traits>);
// print("smem_size: %d\n", smem_size);
// print("smem_size: %d, %d\n", smem_size, Traits::kSmemSize);

auto mla_kernel = mla_kernel_sm80<Traits, Params>;
C10_CUDA_CHECK(cudaFuncSetAttribute(
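The smem_size print just above accompanies the cudaFuncSetAttribute call that opts the kernel into more dynamic shared memory than the default per-block limit. A minimal, self-contained sketch of that pattern follows; dummy_kernel and the 96 KB figure are placeholders (not the MLA kernel or its real shared-memory size), and it assumes a GPU whose opt-in shared-memory limit is at least 96 KB per block (e.g. SM80).

#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out) {
  extern __shared__ float smem[];  // dynamic shared memory, sized at launch
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
  // Raise the kernel's dynamic shared-memory cap above the 48 KB default,
  // mirroring what launch_mla_kernel_sm80 does before launching.
  const int smem_size = 96 * 1024;
  cudaFuncSetAttribute(dummy_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       smem_size);

  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  dummy_kernel<<<1, 128, smem_size>>>(out);  // pass the dynamic smem size at launch
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}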
283 changes: 283 additions & 0 deletions src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
@@ -0,0 +1,283 @@
#include <absl/random/random.h>
#include <gtest/gtest.h>
#include <torch/torch.h>

#include "cute/layout.hpp"
#include "mla_kernel_sm80.cuh"
#include "mla_params.h"
#include "mla_ref.h"
#include "mla_traits_sm80.h"

namespace llm {

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V == 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 64; \
constexpr static int BLK_K = 128; \
constexpr static int STAGES = 2; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V == 256) { \
constexpr static int HEAD_DIM_NAME = 256; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 32; \
constexpr static int BLK_K = 128; \
constexpr static int STAGES = 2; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V == 512) { \
constexpr static int HEAD_DIM_NAME = 512; \
constexpr static int BLK_M = 64; \
constexpr static int BLK_N = 16; \
constexpr static int BLK_K = 128; \
constexpr static int STAGES = 1; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

#define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \
[&] { \
if (ROPE_HEAD_DIM_V == 64) { \
constexpr static int ROPE_HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

#define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \
[&] { \
if (TORCH_DTYPE == torch::kHalf) { \
using TYPE_NAME = cute::half_t; \
return __VA_ARGS__(); \
} else if (TORCH_DTYPE == torch::kBFloat16) { \
using TYPE_NAME = cute::bfloat16_t; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

namespace {
torch::Tensor mla_pagedkv_sm80(
torch::Tensor q, // [q_seq_len, n_heads, head_dim]
torch::Tensor kv_cache, // [n_slots, head_dim]
torch::Tensor q_rope, // [q_seq_len, n_heads, rope_head_dim]
torch::Tensor k_rope_cache, // [n_slots, rope_head_dim]
torch::Tensor q_cu_lens, // [batch_size+1]
torch::Tensor kv_cu_lens, // [batch_size+1]
torch::Tensor block_table, // [n_blocks]
torch::Tensor block_cu_lens, // [batch_size+1]
int block_size,
int max_q_len,
float sm_scale) {
const auto batch_size = q_cu_lens.size(0) - 1;
const auto n_heads = q.size(-2);
const auto head_dim = q.size(-1);
const auto rope_head_dim = q_rope.size(-1);

auto out = torch::empty_like(q);

// construct attention params
MLAPagedKVParams params;
params.q_ptr = q.const_data_ptr();
params.q_stride = make_stride(q.stride(0), q.stride(1));
params.kv_ptr = kv_cache.const_data_ptr();
params.kv_stride = make_stride(kv_cache.stride(0));
params.q_rope_ptr = q_rope.const_data_ptr();
params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1));
params.k_rope_ptr = k_rope_cache.const_data_ptr();
params.k_rope_stride = make_stride(k_rope_cache.stride(0));

params.o_ptr = out.mutable_data_ptr();
params.o_stride = make_stride(out.stride(0), out.stride(1));

params.batch_size = batch_size;
params.max_q_len = max_q_len;
params.n_heads = n_heads;
params.head_dim = head_dim;
params.rope_head_dim = rope_head_dim;
params.sm_scale = sm_scale;

params.q_cu_lens = q_cu_lens.const_data_ptr<int32_t>();
params.kv_cu_lens = kv_cu_lens.const_data_ptr<int32_t>();

params.block_table = block_table.const_data_ptr<int32_t>();
params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();
params.block_size = block_size;

params.normalize();

DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
DISPATCH_ROPE_HEAD_DIM_(rope_head_dim, ROPE_HEAD_DIM, [&] {
using Traits = MLATraitsSM80<DTYPE,
HEAD_DIM,
ROPE_HEAD_DIM,
BLK_M,
BLK_N,
BLK_K,
STAGES>;

launch_mla_kernel_sm80<Traits>(params, nullptr);
});
});
});
return out;
}

} // namespace

class MLAKernelPagedKVTest
: public ::testing::TestWithParam<std::tuple<torch::ScalarType /*q_dtype*/,
int64_t /*batch_size*/,
int64_t /*block_size*/,
int64_t /*q_len*/,
int64_t /*kv_len*/,
int64_t /*n_heads*/,
int64_t /*head_dim*/,
int64_t /*rope_head_dim*/>> {
public:
void SetUp() override {
// Set random seed for test stability
torch::manual_seed(0);
}
};

TEST_P(MLAKernelPagedKVTest, PageKV) {
const auto [dtype,
batch_size,
block_size,
max_q_len,
max_kv_len,
n_heads,
head_dim,
rope_head_dim] = GetParam();

const auto options = torch::dtype(dtype).device(torch::kCUDA);

std::vector<int32_t> block_table_vec;
std::vector<int32_t> block_cu_lens_vec = {0};
std::vector<int> slot_ids;

const int32_t total_blocks = (max_kv_len * batch_size) / block_size + 2;
  // randomly generate per-sequence lengths: q_len in [1, max_q_len], kv_len in [q_len, max_kv_len]
std::vector<int32_t> q_cu_lens_vec = {0};
std::vector<int32_t> kv_cu_lens_vec = {0};
int32_t n_kv_tokens = 0;
int32_t n_q_tokens = 0;
absl::BitGen gen;
for (int i = 0; i < batch_size; ++i) {
    // q_len: [1, max_q_len]
const int32_t q_len =
absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, max_q_len);
n_q_tokens += q_len;
q_cu_lens_vec.push_back(n_q_tokens);

// kv_len >= q_len
int32_t kv_len = q_len;
if (q_len < max_kv_len) {
      // sample kv_len from [q_len, max_kv_len]
kv_len = absl::Uniform<int>(
absl::IntervalClosedClosed, gen, q_len, max_kv_len);
}
n_kv_tokens += kv_len;
kv_cu_lens_vec.push_back(n_kv_tokens);
assert(kv_len >= q_len);

// assign blocks for each sequence
const int32_t n_blocks = (kv_len + block_size - 1) / block_size;
std::vector<int32_t> block_ids;
block_ids.reserve(n_blocks);
for (int j = 0; j < n_blocks; ++j) {
      // randomly pick a block id from the pool
const int32_t id = absl::Uniform<int>(
absl::IntervalClosedClosed, gen, 1, total_blocks - 1);
// put first slot id of each block into block_table
block_ids.push_back(id * block_size);
}
block_table_vec.insert(
block_table_vec.end(), block_ids.begin(), block_ids.end());
block_cu_lens_vec.push_back(block_table_vec.size());

for (int j = 0; j < kv_len; ++j) {
const int32_t slot_base = block_ids[j / block_size];
const int32_t block_offset = j % block_size;
slot_ids.push_back(slot_base + block_offset);
}
}

  // generate packed (varlen) query and randomly-filled paged kv caches
torch::Tensor q = torch::rand({n_q_tokens, n_heads, head_dim}, options);
const auto n_slots = total_blocks * block_size;
torch::Tensor kv_cache = torch::rand({n_slots, head_dim}, options);

torch::Tensor q_rope =
torch::rand({n_q_tokens, n_heads, rope_head_dim}, options);
torch::Tensor k_rope_cache = torch::rand({n_slots, rope_head_dim}, options);

torch::Tensor q_cu_lens = torch::tensor(
q_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::Tensor kv_cu_lens = torch::tensor(
kv_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));

torch::Tensor block_table = torch::tensor(
block_table_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::Tensor block_cu_lens = torch::tensor(
block_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));

  // gather kv and k_rope from the paged caches into contiguous tensors for the reference
std::vector<torch::Tensor> kvs;
kvs.reserve(slot_ids.size());
std::vector<torch::Tensor> k_ropes;
k_ropes.reserve(slot_ids.size());
for (int slot_id : slot_ids) {
    // kv = kv_cache[slot_id, :]
kvs.push_back(kv_cache[slot_id]);
k_ropes.push_back(k_rope_cache[slot_id]);
}
torch::Tensor kv = torch::stack(kvs, /*dim=*/0);
torch::Tensor k_rope = torch::stack(k_ropes, /*dim=*/0);

const float sm_scale = 1.0 / sqrt(head_dim + rope_head_dim);

auto ref_out =
mla_varlen_ref(q, kv, q_rope, k_rope, q_cu_lens, kv_cu_lens, sm_scale);
auto out = mla_pagedkv_sm80(q,
kv_cache,
q_rope,
k_rope_cache,
q_cu_lens,
kv_cu_lens,
block_table,
block_cu_lens,
block_size,
max_q_len,
sm_scale);

// std::cerr << "max diff: " << (ref_out - out).abs().max() << std::endl;
if (dtype == torch::kBFloat16) {
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2));
} else {
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
}
}

INSTANTIATE_TEST_SUITE_P(
MLA,
MLAKernelPagedKVTest,
::testing::Combine(::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1, 2, 4), // batch_size
::testing::Values(1, 8, 64), // block_size
::testing::Values(1, 125), // max_q_len
::testing::Values(127, 1000), // max_kv_len
::testing::Values(1, 8, 128), // n_heads
::testing::Values(512), // head_dim
::testing::Values(64) // rope_head_dim
));

} // namespace llm
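One detail of the test worth calling out: variable-length batching is described entirely by cumulative-length arrays of size batch_size + 1 (q_cu_lens, kv_cu_lens, block_cu_lens). Tokens of sequence i occupy rows [cu_lens[i], cu_lens[i+1]) of the packed tensor, and block_cu_lens delimits each sequence's entries in the flat block_table the same way. A small host-side sketch under that assumption (seq_range is an illustrative helper, not part of the diff):

#include <cstdint>
#include <utility>
#include <vector>

// Given cumulative lengths {0, l0, l0+l1, ...}, return the [begin, end) row
// range owned by sequence i in the packed (varlen) layout.
inline std::pair<int32_t, int32_t> seq_range(const std::vector<int32_t>& cu_lens,
                                             int32_t i) {
  return {cu_lens[i], cu_lens[i + 1]};
}

// Example: per-sequence q lengths {3, 5} give q_cu_lens {0, 3, 8};
// sequence 1 owns packed rows [3, 8). The test builds q_cu_lens_vec,
// kv_cu_lens_vec and block_cu_lens_vec exactly this way, by pushing
// running totals.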
2 changes: 1 addition & 1 deletion src/kernels/attention/mla_kernel_sm80_test.cu
@@ -6,7 +6,7 @@
#include <iostream>

#include "cute/numeric/numeric_types.hpp"
#include "mla_kernel_sm80.cuh" // IWYU pragma: keep
#include "mla_kernel_sm80.cuh"
#include "mla_params.h"
#include "mla_ref.h"
#include "mla_traits_sm80.h"