kernel: added paged kv support for MLA kernel
guocuimi committed Feb 27, 2025
1 parent b8cba27 commit 40607dd
Showing 5 changed files with 406 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/kernels/attention/CMakeLists.txt
@@ -77,6 +77,7 @@ cc_test(
  SRCS
    mla_traits_test.cpp
    mla_kernel_sm80_test.cu
    mla_kernel_sm80_pagedkv_test.cu
  DEPS
    :attention.template
    absl::random_random
279 changes: 279 additions & 0 deletions src/kernels/attention/mla_kernel_sm80_pagedkv_test.cu
@@ -0,0 +1,279 @@
#include <ATen/core/TensorBody.h>
#include <absl/random/random.h>
#include <gtest/gtest.h>
#include <torch/torch.h>

#include <cmath>

#include "cute/layout.hpp"
#include "mla_kernel_sm80.cuh"
#include "mla_params.h"
#include "mla_ref.h"
#include "mla_traits_sm80.h"

namespace llm {

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
  [&] {                                                    \
    if (HEAD_DIM_V == 128) {                               \
      constexpr static int HEAD_DIM_NAME = 128;            \
      constexpr static int BLK_M = 64;                     \
      constexpr static int BLK_N = 64;                     \
      constexpr static int BLK_K = 128;                    \
      constexpr static int STAGES = 2;                     \
      return __VA_ARGS__();                                \
    } else if (HEAD_DIM_V == 256) {                        \
      constexpr static int HEAD_DIM_NAME = 256;            \
      constexpr static int BLK_M = 64;                     \
      constexpr static int BLK_N = 32;                     \
      constexpr static int BLK_K = 128;                    \
      constexpr static int STAGES = 2;                     \
      return __VA_ARGS__();                                \
    } else if (HEAD_DIM_V == 512) {                        \
      constexpr static int HEAD_DIM_NAME = 512;            \
      constexpr static int BLK_M = 64;                     \
      constexpr static int BLK_N = 16;                     \
      constexpr static int BLK_K = 128;                    \
      constexpr static int STAGES = 1;                     \
      return __VA_ARGS__();                                \
    } else {                                               \
      assert(false);                                       \
    }                                                      \
  }()

#define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \
  [&] {                                                                   \
    if (ROPE_HEAD_DIM_V == 64) {                                          \
      constexpr static int ROPE_HEAD_DIM_NAME = 64;                       \
      return __VA_ARGS__();                                               \
    } else {                                                              \
      assert(false);                                                      \
    }                                                                     \
  }()

#define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \
  [&] {                                                    \
    if (TORCH_DTYPE == torch::kHalf) {                     \
      using TYPE_NAME = cute::half_t;                      \
      return __VA_ARGS__();                                \
    } else if (TORCH_DTYPE == torch::kBFloat16) {          \
      using TYPE_NAME = cute::bfloat16_t;                  \
      return __VA_ARGS__();                                \
    } else {                                               \
      assert(false);                                       \
    }                                                      \
  }()

namespace {
torch::Tensor mla_pagedkv_sm80(
    torch::Tensor q,             // [q_seq_len, n_heads, head_dim]
    torch::Tensor kv_cache,      // [n_slots, head_dim]
    torch::Tensor q_rope,        // [q_seq_len, n_heads, rope_head_dim]
    torch::Tensor k_rope_cache,  // [n_slots, rope_head_dim]
    torch::Tensor q_cu_lens,     // [batch_size+1]
    torch::Tensor kv_cu_lens,    // [batch_size+1]
    torch::Tensor block_table,   // [n_blocks]
    torch::Tensor block_cu_lens, // [batch_size+1]
    int block_size,
    int max_q_len,
    float sm_scale) {
  const auto batch_size = q_cu_lens.size(0) - 1;
  const auto n_heads = q.size(-2);
  const auto head_dim = q.size(-1);
  const auto rope_head_dim = q_rope.size(-1);

  auto out = torch::empty_like(q);

  // construct attention params
  MLAPagedKVParams params;
  params.q_ptr = q.const_data_ptr();
  params.q_stride = make_stride(q.stride(0), q.stride(1));
  params.kv_ptr = kv_cache.const_data_ptr();
  params.kv_stride = make_stride(kv_cache.stride(0));
  params.q_rope_ptr = q_rope.const_data_ptr();
  params.q_rope_stride = make_stride(q_rope.stride(0), q_rope.stride(1));
  params.k_rope_ptr = k_rope_cache.const_data_ptr();
  params.k_rope_stride = make_stride(k_rope_cache.stride(0));

  params.o_ptr = out.mutable_data_ptr();
  params.o_stride = make_stride(out.stride(0), out.stride(1));

  params.batch_size = batch_size;
  params.max_q_len = max_q_len;
  params.n_heads = n_heads;
  params.head_dim = head_dim;
  params.rope_head_dim = rope_head_dim;
  params.sm_scale = sm_scale;

  params.q_cu_lens = q_cu_lens.const_data_ptr<int32_t>();
  params.kv_cu_lens = kv_cu_lens.const_data_ptr<int32_t>();

  params.block_table = block_table.const_data_ptr<int32_t>();
  params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();
  params.block_size = block_size;

  params.normalize();

  DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] {
    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
      DISPATCH_ROPE_HEAD_DIM_(rope_head_dim, ROPE_HEAD_DIM, [&] {
        using Traits = MLATraitsSM80<DTYPE,
                                     HEAD_DIM,
                                     ROPE_HEAD_DIM,
                                     BLK_M,
                                     BLK_N,
                                     BLK_K,
                                     STAGES>;

        launch_mla_kernel_sm80<Traits>(params, nullptr);
      });
    });
  });
  return out;
}

} // namespace

class MLAKernelPagedKVTest
    : public ::testing::TestWithParam<std::tuple<torch::ScalarType /*q_dtype*/,
                                                 int64_t /*batch_size*/,
                                                 int64_t /*block_size*/,
                                                 int64_t /*q_len*/,
                                                 int64_t /*kv_len*/,
                                                 int64_t /*n_heads*/,
                                                 int64_t /*head_dim*/,
                                                 int64_t /*rope_head_dim*/>> {
 public:
  void SetUp() override {
    // Set random seed for test stability
    torch::manual_seed(0);
  }
};

TEST_P(MLAKernelPagedKVTest, PageKV) {
  const auto [dtype,
              batch_size,
              block_size,
              max_q_len,
              max_kv_len,
              n_heads,
              head_dim,
              rope_head_dim] = GetParam();

  const auto options = torch::dtype(dtype).device(torch::kCUDA);

  std::vector<int32_t> block_table_vec;
  std::vector<int32_t> block_cu_lens_vec = {0};
  std::vector<int> slot_ids;

  const int32_t total_blocks = (max_kv_len * batch_size) / block_size + 2;
  // randomly generate per-sequence q/kv lengths
  std::vector<int32_t> q_cu_lens_vec = {0};
  std::vector<int32_t> kv_cu_lens_vec = {0};
  int32_t n_kv_tokens = 0;
  int32_t n_q_tokens = 0;
  absl::BitGen gen;
  for (int i = 0; i < batch_size; ++i) {
    // sample q_len from [1, max_q_len]
    const int32_t q_len =
        absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, max_q_len);
    n_q_tokens += q_len;
    q_cu_lens_vec.push_back(n_q_tokens);

    // kv_len >= q_len
    int32_t kv_len = q_len;
    if (q_len < max_kv_len) {
      // sample kv_len from [q_len, max_kv_len]
      kv_len = absl::Uniform<int>(
          absl::IntervalClosedClosed, gen, q_len, max_kv_len);
    }
    n_kv_tokens += kv_len;
    kv_cu_lens_vec.push_back(n_kv_tokens);
    assert(kv_len >= q_len);

    // assign blocks to each sequence
    const int32_t n_blocks = (kv_len + block_size - 1) / block_size;
    std::vector<int32_t> block_ids;
    block_ids.reserve(n_blocks);
    for (int j = 0; j < n_blocks; ++j) {
      // randomly pick a block id in [1, total_blocks - 1]
      const int32_t id = absl::Uniform<int>(
          absl::IntervalClosedClosed, gen, 1, total_blocks - 1);
      // record the first slot id of the block in the block table
      block_ids.push_back(id * block_size);
    }
    block_table_vec.insert(
        block_table_vec.end(), block_ids.begin(), block_ids.end());
    block_cu_lens_vec.push_back(block_table_vec.size());

    // map each kv token to its slot id in the paged cache
    for (int j = 0; j < kv_len; ++j) {
      const int32_t slot_base = block_ids[j / block_size];
      const int32_t block_offset = j % block_size;
      slot_ids.push_back(slot_base + block_offset);
    }
  }
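
  // For intuition, a hypothetical layout for batch_size = 2, block_size = 4,
  // and kv lens {6, 3} (illustrative values, not what the RNG produces):
  //   block_table_vec   = {8, 20, 4};  // first slot id of each block
  //   block_cu_lens_vec = {0, 2, 3};   // seq 0 owns blocks [0, 2), seq 1 [2, 3)
  //   slot_ids          = {8, 9, 10, 11, 20, 21,  // seq 0: 6 kv tokens
  //                        4, 5, 6};              // seq 1: 3 kv tokens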

  // generate query and randomly initialized paged kv caches
  torch::Tensor q = torch::rand({n_q_tokens, n_heads, head_dim}, options);
  const auto n_slots = total_blocks * block_size;
  torch::Tensor kv_cache = torch::rand({n_slots, head_dim}, options);

  torch::Tensor q_rope =
      torch::rand({n_q_tokens, n_heads, rope_head_dim}, options);
  torch::Tensor k_rope_cache = torch::rand({n_slots, rope_head_dim}, options);

  torch::Tensor q_cu_lens = torch::tensor(
      q_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));
  torch::Tensor kv_cu_lens = torch::tensor(
      kv_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));

  torch::Tensor block_table = torch::tensor(
      block_table_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));
  torch::Tensor block_cu_lens = torch::tensor(
      block_cu_lens_vec, torch::dtype(torch::kInt32).device(torch::kCUDA));

  // gather the contiguous kv and k_rope from the caches using slot ids
  std::vector<torch::Tensor> kvs;
  kvs.reserve(slot_ids.size());
  std::vector<torch::Tensor> k_ropes;
  k_ropes.reserve(slot_ids.size());
  for (int slot_id : slot_ids) {
    // kv = kv_cache[slot_id, :]
    kvs.push_back(kv_cache[slot_id]);
    k_ropes.push_back(k_rope_cache[slot_id]);
  }
  const auto kv = torch::stack(kvs, /*dim=*/0);
  const auto k_rope = torch::stack(k_ropes, /*dim=*/0);

  // softmax scale; assumed convention: 1 / sqrt(qk_dim) with
  // qk_dim = head_dim + rope_head_dim
  const float sm_scale =
      1.0f / std::sqrt(static_cast<float>(head_dim + rope_head_dim));

  // reference output over the gathered contiguous kv (assumed signature)
  auto ref_out =
      mla_varlen_ref(q, kv, q_rope, k_rope, q_cu_lens, kv_cu_lens, sm_scale);

  auto out = mla_pagedkv_sm80(q,
                              kv_cache,
                              q_rope,
                              k_rope_cache,
                              q_cu_lens,
                              kv_cu_lens,
                              block_table,
                              block_cu_lens,
                              block_size,
                              max_q_len,
                              sm_scale);

  EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
}

INSTANTIATE_TEST_SUITE_P(
    MLA,
    MLAKernelPagedKVTest,
    ::testing::Combine(::testing::Values(torch::kHalf,
                                         torch::kBFloat16),  // q_dtype
                       ::testing::Values(1, 2, 4),           // batch_size
                       ::testing::Values(1, 8),              // block_size
                       ::testing::Values(1, 125),            // max_q_len
                       ::testing::Values(127, 1000),         // max_kv_len
                       ::testing::Values(1, 8, 128),         // n_heads
                       ::testing::Values(128, 256, 512),     // head_dim
                       ::testing::Values(64)                 // rope_head_dim
                       ));

} // namespace llm
2 changes: 1 addition & 1 deletion src/kernels/attention/mla_kernel_sm80_test.cu
@@ -6,7 +6,7 @@
 #include <iostream>

 #include "cute/numeric/numeric_types.hpp"
-#include "mla_kernel_sm80.cuh"  // IWYU pragma: keep
+#include "mla_kernel_sm80.cuh"
 #include "mla_params.h"
 #include "mla_ref.h"
 #include "mla_traits_sm80.h"
48 changes: 46 additions & 2 deletions src/kernels/attention/mla_params.h
@@ -23,17 +23,22 @@ struct MLAParamsCommon {
  int head_dim = 0;
  int rope_head_dim = 0;

  // softmax scaling
  float sm_scale = 1.0;

  // used for scheduling
  // TODO: remove it after persistent kernel
  int max_q_len = 0;

  // block size, only used for paged KV cache
  int block_size = 0;

  // private:
  // used for performance optimization, don't change it
  bool normalized = false;
  float sm_scale_log2 = 0.0;
  int32_t block_shift_right = 0;
  int32_t block_mask = 0;

  // used to initialize the params used for performance optimization
  void normalize() {
@@ -42,7 +47,19 @@
      return;
    }
    sm_scale_log2 = static_cast<float>(sm_scale * M_LOG2E);

    // block size is only set for paged KV cache; it must be a power of 2
    // so that div/mod can be lowered to shift/mask
    if (block_size > 0) {
      assert((block_size & (block_size - 1)) == 0);
      auto int_log2 = [](int x) {
        int n = 0;
        while (x >>= 1) {
          ++n;
        }
        return n;
      };
      block_shift_right = int_log2(block_size);
      block_mask = block_size - 1;
    }

    normalized = true;
  }
};
@@ -66,4 +83,31 @@ struct MLAParams : public MLAParamsCommon {
  int kv_len = 0;
};

// paged KV cache + variable length sequence
struct MLAPagedKVParams : public MLAParamsCommon {
  // Q/O: (seq, head, dim): last dimension is contiguous
  using Stride = cute::Stride<int64_t, int64_t /*,_1*/>;
  // KV: (seq, dim): last dimension is contiguous
  using KV_Stride = cute::Stride<int64_t /*,_1*/>;

  Stride q_stride;
  Stride q_rope_stride;

  KV_Stride kv_stride;
  KV_Stride k_rope_stride;

  Stride o_stride;

  // input shapes
  // arrays of length batch_size + 1 holding the starting token offset of
  // each sequence
  const int* __restrict__ q_cu_lens = nullptr;
  const int* __restrict__ kv_cu_lens = nullptr;

  // paged KV cache
  // the first slot id of each block
  const int* __restrict__ block_table = nullptr;
  // array of length batch_size + 1 holding the starting offset of each
  // sequence's blocks in block_table
  const int* __restrict__ block_cu_lens = nullptr;
};
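
// Illustration only (not part of this commit): per-sequence extents are
// recovered from the cumulative arrays above as adjacent differences.
inline int q_len_of(const MLAPagedKVParams& params, int seq_idx) {
  return params.q_cu_lens[seq_idx + 1] - params.q_cu_lens[seq_idx];
}
inline int n_blocks_of(const MLAPagedKVParams& params, int seq_idx) {
  return params.block_cu_lens[seq_idx + 1] - params.block_cu_lens[seq_idx];
}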

} // namespace llm
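
The block_shift_right/block_mask pair precomputed in normalize() exists so a kernel can map a kv token index to its cache slot with shift/mask bit operations instead of integer div/mod. A minimal sketch of that mapping, assuming (as the test above does) that block_table stores the first slot id of each block; the helper kv_slot is hypothetical and not part of this commit:

// Hypothetical device-side helper (illustration only).
__device__ __forceinline__ int kv_slot(const int* __restrict__ block_table,
                                       int token_idx,
                                       int block_shift_right,
                                       int block_mask) {
  const int block_idx = token_idx >> block_shift_right;  // token_idx / block_size
  const int offset = token_idx & block_mask;             // token_idx % block_size
  return block_table[block_idx] + offset;                // slot id in kv cache
}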