
Commit 9e28cb5
added test for bf16 type
guocuimi committed Feb 27, 2025
1 parent 14beb48 commit 9e28cb5
Showing 1 changed file with 39 additions and 19 deletions.
src/kernels/attention/mla_kernel_sm80_test.cu
@@ -51,6 +51,19 @@ namespace llm {
     }                                                      \
   }()
 
+#define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \
+  [&] {                                                    \
+    if (TORCH_DTYPE == torch::kHalf) {                     \
+      using TYPE_NAME = cute::half_t;                      \
+      return __VA_ARGS__();                                \
+    } else if (TORCH_DTYPE == torch::kBFloat16) {          \
+      using TYPE_NAME = cute::bfloat16_t;                  \
+      return __VA_ARGS__();                                \
+    } else {                                               \
+      assert(false);                                       \
+    }                                                      \
+  }()
+
 namespace {
 torch::Tensor mla_sm80(
     torch::Tensor q,  // [batch, q_len, n_heads, head_dim]
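For reference, the new DISPATCH_TORCH_DTYPE_ macro follows the same runtime-to-compile-time dispatch pattern as the existing DISPATCH_HEAD_DIM_ and DISPATCH_ROPE_HEAD_DIM_ helpers: it binds TYPE_NAME to a concrete cute type and invokes the trailing lambda. A minimal standalone sketch of its use (element_size_for is illustrative and not part of this commit; the cute header path may differ across CUTLASS versions):

    #include <cassert>
    #include <cstddef>
    #include <cute/numeric/numeric_types.hpp>  // cute::half_t, cute::bfloat16_t
    #include <torch/torch.h>

    // Hypothetical helper: maps a runtime torch dtype to the byte size of the
    // matching compile-time cute type via the macro's immediately-invoked lambda.
    size_t element_size_for(torch::ScalarType dtype) {
      return DISPATCH_TORCH_DTYPE_(dtype, DTYPE, [&] {
        return sizeof(DTYPE);  // DTYPE is cute::half_t or cute::bfloat16_t here
      });
    }
    // element_size_for(torch::kHalf) and element_size_for(torch::kBFloat16) both yield 2.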
@@ -93,17 +106,19 @@ torch::Tensor mla_sm80(
   params.sm_scale = sm_scale;
   params.normalize();
 
-  DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-    DISPATCH_ROPE_HEAD_DIM_(rope_head_dim, ROPE_HEAD_DIM, [&] {
-      using Traits = MLATraitsSM80<cute::half_t,
-                                   HEAD_DIM,
-                                   ROPE_HEAD_DIM,
-                                   BLK_M,
-                                   BLK_N,
-                                   BLK_K,
-                                   STAGES>;
-
-      launch_mla_kernel_sm80<Traits>(params, nullptr);
+  DISPATCH_TORCH_DTYPE_(q.dtype(), DTYPE, [&] {
+    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+      DISPATCH_ROPE_HEAD_DIM_(rope_head_dim, ROPE_HEAD_DIM, [&] {
+        using Traits = MLATraitsSM80<DTYPE,
+                                     HEAD_DIM,
+                                     ROPE_HEAD_DIM,
+                                     BLK_M,
+                                     BLK_N,
+                                     BLK_K,
+                                     STAGES>;
+
+        launch_mla_kernel_sm80<Traits>(params, nullptr);
+      });
     });
   });
   return out;
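With the dtype dispatch added, the nested macros above amount to picking one MLATraitsSM80 instantiation at runtime. Showing only the dtype axis, the dispatch is conceptually equivalent to (a sketch of the expansion, not literal generated code):

    // HEAD_DIM, ROPE_HEAD_DIM, BLK_M, BLK_N, BLK_K, STAGES are the compile-time
    // constants chosen by the inner dispatch macros.
    if (q.dtype() == torch::kHalf) {
      using Traits = MLATraitsSM80<cute::half_t, HEAD_DIM, ROPE_HEAD_DIM,
                                   BLK_M, BLK_N, BLK_K, STAGES>;
      launch_mla_kernel_sm80<Traits>(params, nullptr);
    } else if (q.dtype() == torch::kBFloat16) {
      using Traits = MLATraitsSM80<cute::bfloat16_t, HEAD_DIM, ROPE_HEAD_DIM,
                                   BLK_M, BLK_N, BLK_K, STAGES>;
      launch_mla_kernel_sm80<Traits>(params, nullptr);
    }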
@@ -153,19 +168,24 @@ TEST_P(MLAKernelTest, MLA) {
   auto ref_out = mla_batch_ref(q, kv, q_rope, k_rope, sm_scale);
   auto out = mla_sm80(q, kv, q_rope, k_rope, sm_scale);
   // std::cerr << "max diff: " << (ref_out - out).abs().max() << std::endl;
-  EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
+  if (dtype == torch::kBFloat16) {
+    EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2));
+  } else {
+    EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3));
+  }
 }
 
 INSTANTIATE_TEST_SUITE_P(
     MLA,
     MLAKernelTest,
-    ::testing::Combine(::testing::Values(torch::kHalf),     // q_dtype
-                       ::testing::Values(1, 2, 4, 10),      // batch_size
-                       ::testing::Values(1, 62, 125),       // q_len
-                       ::testing::Values(127, 287, 1000),   // kv_len
-                       ::testing::Values(1, 8, 128),        // n_heads
-                       ::testing::Values(128, 256, 512),    // head_dim
-                       ::testing::Values(64)                // rope_head_dim
+    ::testing::Combine(::testing::Values(torch::kHalf,
+                                         torch::kBFloat16),  // q_dtype
+                       ::testing::Values(1, 2, 4, 10),       // batch_size
+                       ::testing::Values(1, 62, 125),        // q_len
+                       ::testing::Values(127, 287, 1000),    // kv_len
+                       ::testing::Values(1, 8, 128),         // n_heads
+                       ::testing::Values(128, 256, 512),     // head_dim
+                       ::testing::Values(64)                 // rope_head_dim
                        ));
 
 } // namespace llm
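A plausible rationale for the per-dtype tolerances (my reading; the commit does not state one): bfloat16 keeps 7 explicit mantissa bits versus fp16's 10, so its unit round-off is roughly 8x larger, which matches loosening rtol/atol from 1e-3 to 1e-2. A quick check:

    #include <cmath>
    #include <cstdio>

    int main() {
      // Unit round-off is about 2^-(explicit mantissa bits).
      std::printf("fp16 epsilon ~ %g\n", std::ldexp(1.0, -10));  // ~9.8e-4, near rtol=1e-3
      std::printf("bf16 epsilon ~ %g\n", std::ldexp(1.0, -7));   // ~7.8e-3, near rtol=1e-2
      return 0;
    }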
