diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt
index 5faec0be..d1cf0149 100644
--- a/src/kernels/attention/CMakeLists.txt
+++ b/src/kernels/attention/CMakeLists.txt
@@ -17,6 +17,10 @@ cc_library(
     mha_traits_sm80.h
     mha_kernel_sm80.cuh
     mha_dispatch_sm80.cuh
+    mla_params.h
+    mla_tile.h
+    mla_traits_sm80.h
+    mla_kernel_sm80.cuh
   DEPS
     cutlass
 )
@@ -67,6 +71,19 @@ cc_test(
     torch
 )
 
+cc_test(
+  NAME
+    mla_kernel_test
+  SRCS
+    mla_traits_test.cpp
+    mla_kernel_sm80_test.cu
+  DEPS
+    :attention.template
+    absl::random_random
+    GTest::gtest_main
+    torch
+)
+
 nvbench_binary(
   NAME
     mha_sm80_bench
@@ -86,4 +103,13 @@ nvbench_binary(
     :attention.template
 )
 
+nvbench_binary(
+  NAME
+    mla_sm80_bench
+  SRCS
+    mla_sm80_bench.cu
+  DEPS
+    :attention.template
+)
+
 add_subdirectory(tools)
\ No newline at end of file
diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/cute_extensions.cuh
index 7f490062..36d7bd7a 100644
--- a/src/kernels/attention/cute_extensions.cuh
+++ b/src/kernels/attention/cute_extensions.cuh
@@ -20,6 +20,12 @@ constexpr bool
     .with(declval()))>> = true;
 }  // namespace detail
 
+template <int... Is, int B, int M, int S, class Offset, class LayoutB>
+CUTE_HOST_DEVICE constexpr auto permute(
+    const ComposedLayout<Swizzle<B, M, S>, Offset, LayoutB>& c) {
+  return composition(c.layout_a(), c.offset(), select<Is...>(c.layout_b()));
+}
+
 template <class IntTupleA, class IntTupleB>
 CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a, IntTupleB const& b) {
diff --git a/src/kernels/attention/mha_dispatch_sm80.cuh b/src/kernels/attention/mha_dispatch_sm80.cuh
index 7ba20b39..314cbb6b 100644
--- a/src/kernels/attention/mha_dispatch_sm80.cuh
+++ b/src/kernels/attention/mha_dispatch_sm80.cuh
@@ -46,6 +46,16 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
   params.normalize();
 
   // TODO: tune block shape MNK based on the head dim and smem size
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
+  // Max SMEM (KB)|  96 |  96 |  64 | 164 | 100 | 164 | 100 | 228 |  228 | 100 |
+  // valid dynamic shared memory sizes for different compute capabilities:
+  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
+  // * 7.5       : 0, 32, 64
+  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
+  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
+  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
+  // * 12.0      : 0, 8, 16, 32, 64, 100
   if constexpr (HEAD_DIM == 64) {
     using Traits = MHATraitsSM80
 #include
+#include "cute_extensions.cuh"
 namespace llm {
 using namespace cute;
@@ -79,10 +80,8 @@ struct MHATraitsSM80 {
   using SmemLayoutV =
       decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{}));
 
-  // V^T smem: (HEAD_DIM, BLK_N) row-major
-  using SmemLayoutVt = decltype(composition(
-      SmemLayoutV{},
-      make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{})));
+  // V^T smem: (HEAD_DIM, BLK_N)
+  using SmemLayoutVt = decltype(permute<1, 0>(SmemLayoutV{}));
 
   // Thr layout for gmem copy
   using GmemCopyThrLayout =
diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh
new file mode 100644
index 00000000..09e9a077
--- /dev/null
+++ b/src/kernels/attention/mla_kernel_sm80.cuh
@@ -0,0 +1,370 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "cute/config.hpp"
+#include "cute_extensions.cuh"
+#include "fast_cast.cuh"
+#include "mask.h"
+#include "mla_tile.h"
+#include "online_softmax.cuh"
+#include "ptx.cuh"
+
+namespace llm {
+
+template <typename Traits, typename Params>
+__global__ void mla_kernel_sm80(__grid_constant__ const Params params) { + using namespace cute; + + constexpr int kBlockM = Traits::kBlockM; + constexpr int kBlockN = Traits::kBlockN; + constexpr int kHeadDim = Traits::kHeadDim; + constexpr int kRopeHeadDim = Traits::kRopeHeadDim; + constexpr int kRowsPerMMA = Traits::kRowsPerMMA; + + using _BLK_M = Int; + using _BLK_N = Int; + using _HEAD_DIM = Int; + using _ROPE_HEAD_DIM = Int; + + // type alias + using DType = typename Traits::DType; + + using TiledMma = typename Traits::TiledMma; + using Layout = typename Traits::LayoutConvertor; + + using SmemLayoutQ = typename Traits::SmemLayoutQ; + using SmemLayoutKV = typename Traits::SmemLayoutKV; + using SmemLayoutQRope = typename Traits::SmemLayoutQRope; + using SmemLayoutKRope = typename Traits::SmemLayoutKRope; + using SmemLayoutVt = typename Traits::SmemLayoutVt; + using SmemLayoutO = typename Traits::SmemLayoutO; + + using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; + using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; + using GmemTiledCopyO = typename Traits::GmemTiledCopyO; + + using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ; + using SmemTiledCopyK = typename Traits::SmemTiledCopyK; + using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; + using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + + const int m_block = blockIdx.x; + const int batch_idx = blockIdx.y; + const int tidx = threadIdx.x; + + const float sm_scale_log2 = params.sm_scale_log2; + + MLATile tile(params); + + // ProblemShape + // Q/O: (q_packed_len, HEAD_DIM) + // KV: (kv_len, HEAD_DIM) + // Q/K_ROPE: (q_packed_len, ROPE_HEAD_DIM) + auto [Q, Q_ROPE, O] = tile.template get_qo_tile(batch_idx); + auto [KV, K_ROPE] = tile.template get_kv_tile(batch_idx); + + const int q_packed_len = size<0>(Q); + // const int q_len = q_packed_len / group_size; + const int kv_len = size<0>(KV); + + if (m_block * kBlockM >= q_packed_len) { + // m out of bound, return + return; + } + + // Gmem + // (BLK_M, HEAD_DIM) + Tensor gQ = + local_tile(Q, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block, _0{})); + Tensor gO = + local_tile(O, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block, _0{})); + // (BLK_N, HEAD_DIM, n) + Tensor gKV = local_tile(KV, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); + + // (BLK_M, ROPE_HEAD_DIM) + Tensor gQ_rope = local_tile( + Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(m_block, _0{})); + // (BLK_N, ROPE_HEAD_DIM, n) + Tensor gK_rope = + local_tile(K_ROPE, Shape<_BLK_N, _ROPE_HEAD_DIM>{}, make_coord(_, _0{})); + + // Smem + extern __shared__ char smem[]; + DType* q_smem = (DType*)smem; + DType* kv_smem = q_smem + cosize(SmemLayoutQ{}); + DType* q_rope_smem = kv_smem + cosize(SmemLayoutKV{}); + DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{}); + + // (BLK_M, HEAD_DIM), k-major + Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{}); + // (BLK_N, HEAD_DIM), k-major + Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{}); + + // (BLK_M, ROPE_HEAD_DIM), k-major + Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{}); + // (BLK_N, ROPE_HEAD_DIM), k-major + Tensor sK_rope = make_tensor(make_smem_ptr(k_rope_smem), SmemLayoutKRope{}); + + // Tensor for V^t; used in GEMM-II. 
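+  // NOTE: sVt aliases the same shared memory as sK (kv_smem): SmemLayoutVt is a
+  // (HEAD_DIM, BLK_N) permutation of SmemLayoutKV, so GEMM-II can read V^T
+  // directly from the KV tile without a separate V buffer or explicit transpose.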
+ // (HEAD_DIM, BLK_N), m-major + Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{}); + + // Tiled Copy + // g2s tiled copy for qkv + GmemTiledCopyQ gmem_tiled_copy_Q; + GmemTiledCopyKV gmem_tiled_copy_KV; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); + + auto produce_q = [&]() { + auto tQgQ = gmem_thr_copy_Q.partition_S(gQ); + auto tQsQ = gmem_thr_copy_Q.partition_D(sQ); + cute::copy(gmem_tiled_copy_Q, tQgQ, tQsQ); + }; + + auto produce_q_rope = [&]() { + auto tQgQ_rope = gmem_thr_copy_Q.partition_S(gQ_rope); + auto tQsQ_rope = gmem_thr_copy_Q.partition_D(sQ_rope); + cute::copy(gmem_tiled_copy_Q, tQgQ_rope, tQsQ_rope); + }; + + Tensor tKsKV = gmem_thr_copy_KV.partition_D(sK); + auto produce_kv = [&](int ni) { + auto tKgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni)); + cute::copy(gmem_tiled_copy_KV, tKgKV, tKsKV); + }; + + Tensor tKsK_rope = gmem_thr_copy_KV.partition_D(sK_rope); + auto produce_k_rope = [&](int ni) { + auto tKgK_rope = gmem_thr_copy_KV.partition_S(gK_rope(_, _, ni)); + cute::copy(gmem_tiled_copy_KV, tKgK_rope, tKsK_rope); + }; + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(tidx); + // GEMM-I: S = Q@K.T + auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + auto tSrQ_rope = thr_mma.partition_fragment_A(sQ_rope); // (MMA,MMA_M,MMA_K) + auto tSrK_rope = thr_mma.partition_fragment_B(sK_rope); // (MMA,MMA_N,MMA_K) + + // s2r tiled copy for qkv + SmemTiledCopyQ smem_tiled_copy_Q; + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + auto tCsQ = smem_thr_copy_Q.partition_S(sQ); + auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ); + + auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); + auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope); + + SmemTiledCopyK smem_tiled_copy_K; + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + auto tCsK = smem_thr_copy_K.partition_S(sK); + auto tCrK = smem_thr_copy_K.retile_D(tSrK); + + auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope); + auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope); + + // S = Q@K.T + // tSrS: (MMA,MMA_M,MMA_N) + auto compute_qk = [&](auto& tSrS) { + // prefetch kv + cute::copy(smem_tiled_copy_Q, tCsQ(_, _, _0{}), tCrQ(_, _, _0{})); + cute::copy(smem_tiled_copy_K, tCsK(_, _, _0{}), tCrK(_, _, _0{})); + + CUTE_UNROLL + for (int ki = 0; ki < size<2>(tSrQ); ++ki) { + // prefetch next kv + if (ki != size<2>(tSrQ) - 1) { + const auto next_ki = ki + 1; + cute::copy(smem_tiled_copy_Q, tCsQ(_, _, next_ki), tCrQ(_, _, next_ki)); + cute::copy(smem_tiled_copy_K, tCsK(_, _, next_ki), tCrK(_, _, next_ki)); + } + cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS); + } + }; + + auto compute_qk_rope = [&](auto& tSrS) { + // prefetch qk_rope + cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{})); + cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{})); + + CUTE_UNROLL + for (int ki = 0; ki < size<2>(tSrQ_rope); ++ki) { + // prefetch next qk_rope + if (ki != size<2>(tSrQ_rope) - 1) { + const auto next_ki = ki + 1; + cute::copy(smem_tiled_copy_Q, + tCsQ_rope(_, _, next_ki), + tCrQ_rope(_, _, next_ki)); + cute::copy(smem_tiled_copy_K, + tCsK_rope(_, _, next_ki), + tCrK_rope(_, _, next_ki)); + } + cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrS); + } + }; + + // GEMM-II: O = softmax(S)@V + auto tOrVt = 
thr_mma.partition_fragment_B(sVt); // (MMA,MMA_K,MMA_N) + + SmemTiledCopyVt smem_tiled_copy_Vt; + auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx); + auto tCsVt = smem_thr_copy_Vt.partition_S(sVt); + auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt); + + // O = softmax(S)*V + // tSrS: (MMA,MMA_M,MMA_N) + // tOrAccO: (MMA,MMA_M,MMA_K) + auto compute_sv = [&](const auto& tSrS, auto& tOrO) { + // cast scores from Accumulator to Element + auto tSrS_ = make_tensor_like(tSrS); + fast_cast(tSrS, tSrS_); + + // convert layout from gemm-I C to gemm-II A + auto tOrS = make_tensor(tSrS_.data(), Layout::to_mma_a(tSrS_.layout())); + + // prefetch V^t + cute::copy(smem_tiled_copy_Vt, tCsVt(_, _, _0{}), tCrVt(_, _, _0{})); + CUTE_UNROLL + for (int ki = 0; ki < size<2>(tOrS); ++ki) { + // prefetch next V^t + if (ki != size<2>(tOrS) - 1) { + const auto next_ki = ki + 1; + cute::copy( + smem_tiled_copy_Vt, tCsVt(_, _, next_ki), tCrVt(_, _, next_ki)); + } + cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO); + } + }; + + // tOrO: (MMA,MMA_M,MMA_K) + auto epilogue = [&](const auto& tOrO) { + // write output to gmem + // 1> cast output from ElementAccumulator to Element + auto tOrO_ = make_tensor_like(tOrO); + fast_cast(tOrO, tOrO_); + + auto sO = make_tensor(sQ.data(), SmemLayoutO{}); + // 2. copy output from reg to smem (reuse sQ) + { + SmemTiledCopyO smem_tiled_copy_O; + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + auto tCrO = smem_thr_copy_O.retile_S(tOrO_); + auto tCsO = smem_thr_copy_O.partition_D(sO); + cute::copy(smem_tiled_copy_O, tCrO, tCsO); + } + + // 3. copy output from smem to gmem + { + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + + auto tCsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) + auto tCgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) + + // wait for smem copy done before gmem copy + __syncthreads(); + cute::copy(gmem_tiled_copy_O, tCsO, tCgO); + } + }; + + // output accumulator, (MMA,MMA_M,MMA_K) + auto tOrO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{}); + auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_rowcol(tOrO.layout())); + clear(tOrO); + + const int n_block_min = 0; + const int n_block_max = cute::ceil_div(kv_len, kBlockN); + + // ############### Prologue ############### + // produce query: [] => [q] + produce_q(); + // produce q_rope: [q] => [q, q_rope] + produce_q_rope(); + cp_async_fence(); + // produce key: [q, q_rope] => [q, q_rope, kv] + produce_kv(0); + // produce k_rope: [q, q_rope, kv] => [q, q_rope, kv, k_rope] + produce_k_rope(0); + cp_async_fence(); + + // ############### Mainloop ############### + constexpr int kMMA_M = size<1>(tOrO); + using Softmax = OnlineSoftmax; + Softmax softmax(sm_scale_log2); + + CUTE_NO_UNROLL + for (int ni = n_block_min; ni < n_block_max; ++ni) { + // attention score accumulator, (MMA,MMA_M,MMA_N) + auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); + auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_rowcol(tSrS.layout())); + clear(tSrS); + + // wait key, queue: [q, q_rope, kv, k_rope] => [] + cp_async_wait<0>(); + __syncthreads(); + + // 1> S = Q@K.T + compute_qk(tSrS); + + // 2> S += Q_rope@K_rope.T + compute_qk_rope(tSrS); + + softmax.rescale(tSrS_mn, tOrO_mn); + + // 3> O = softmax(S)*V + compute_sv(tSrS, tOrO); + + // produce next key: [] => [kv, k_rope] + if (ni != n_block_max - 1) { + produce_kv(ni + 1); + produce_k_rope(ni + 1); + } + cp_async_fence(); + } + + // 
############### Epilogue ############### + + // normalize output: o /= rowsum + softmax.finalize(tOrO_mn); + + // write output to gmem + epilogue(tOrO); +} + +template +void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { + const auto batch_size = params.batch_size; + const auto max_q_packed_len = params.max_q_len * params.n_heads; + + const auto smem_size = Traits::kSmemSize; + + auto mla_kernel = + mla_kernel_sm80; + C10_CUDA_CHECK(cudaFuncSetAttribute( + mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // TODO: support persistent kernels + dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, 1); + dim3 block = Traits::kThreadNum; + mla_kernel<<>>(params); +} + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu new file mode 100644 index 00000000..fcec6024 --- /dev/null +++ b/src/kernels/attention/mla_kernel_sm80_test.cu @@ -0,0 +1,127 @@ +#include +#include + +#include +#include +#include + +#include "cute/numeric/numeric_types.hpp" +#include "mla_kernel_sm80.cuh" // IWYU pragma: keep +#include "mla_params.h" +#include "mla_ref.h" +#include "mla_traits_sm80.h" + +namespace llm { + +namespace { +torch::Tensor mla_sm80( + torch::Tensor q, // [batch, q_len, n_heads, head_dim] + torch::Tensor kv, // [batch, kv_len, head_dim] + torch::Tensor q_rope, // [batch, q_len, n_heads, rope_head_dim] + torch::Tensor k_rope, // [batch, kv_len, rope_head_dim] + float sm_scale) { + const auto batch_size = q.size(0); + const auto q_len = q.size(-3); + const auto kv_len = kv.size(-3); + const auto n_heads = q.size(-2); + const auto head_dim = q.size(-1); + const auto rope_head_dim = q_rope.size(-1); + + auto out = torch::empty_like(q); + + // construct attention params + MLAParams params; + params.q_ptr = q.const_data_ptr(); + params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2)); + params.kv_ptr = kv.const_data_ptr(); + params.kv_stride = make_stride(kv.stride(0), kv.stride(1)); + + params.q_rope_ptr = q_rope.const_data_ptr(); + params.q_rope_stride = + make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2)); + params.k_rope_ptr = k_rope.const_data_ptr(); + params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1)); + + params.o_ptr = out.mutable_data_ptr(); + params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); + + params.batch_size = batch_size; + params.max_q_len = q_len; + params.n_heads = n_heads; + params.q_len = q_len; + params.kv_len = kv_len; + params.head_dim = head_dim; + params.rope_head_dim = rope_head_dim; + params.sm_scale = sm_scale; + params.normalize(); + + using Traits = MLATraitsSM80; + launch_mla_kernel_sm80(params, nullptr); + return out; +} + +} // namespace + +class MLAKernelTest + : public ::testing::TestWithParam> { + public: + void SetUp() override { + // Set random seed for test stability + torch::manual_seed(0); + } +}; + +TEST_P(MLAKernelTest, MLA) { + const auto [dtype, + batch_size, + q_len, + kv_len, + n_heads, + head_dim, + rope_head_dim] = GetParam(); + // const auto head_dim = kv_lora_rank + rope_head_dim; + const auto options = torch::dtype(dtype).device(torch::kCUDA); + + // q: [batch, len, n_heads, head_dim] + // kv: [batch, len, head_dim] + const auto q = torch::randn({batch_size, q_len, n_heads, head_dim}, options); + const auto kv = torch::randn({batch_size, kv_len, head_dim}, options); + + // q_rope: [batch, len, n_heads, rope_head_dim] + 
// kv_rope: [batch, len, rope_head_dim] + const auto q_rope = + torch::randn({batch_size, q_len, n_heads, rope_head_dim}, options); + const auto k_rope = + torch::randn({batch_size, kv_len, rope_head_dim}, options); + + const float sm_scale = 1.0 / sqrt(head_dim + rope_head_dim); + + auto ref_out = mla_batch_ref(q, kv, q_rope, k_rope, sm_scale); + auto out = mla_sm80(q, kv, q_rope, k_rope, sm_scale); + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); +} + +INSTANTIATE_TEST_SUITE_P( + MLA, + MLAKernelTest, + ::testing::Combine(::testing::Values(torch::kHalf), // q_dtype + ::testing::Values(1), // batch_size + ::testing::Values(64), // q_len + ::testing::Values(64), // kv_len + ::testing::Values(8), // n_heads + ::testing::Values(64), // head_dim + ::testing::Values(64) // rope_head_dim + )); + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_params.h b/src/kernels/attention/mla_params.h new file mode 100644 index 00000000..1e3415eb --- /dev/null +++ b/src/kernels/attention/mla_params.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include + +#include "cute/layout.hpp" +namespace llm { + +// common params for attention kernels +struct MLAParamsCommon { + const void* __restrict__ q_ptr = nullptr; + const void* __restrict__ kv_ptr = nullptr; + + const void* __restrict__ q_rope_ptr = nullptr; + const void* __restrict__ k_rope_ptr = nullptr; + + void* __restrict__ o_ptr = nullptr; + + // input shapes + int batch_size = 0; + + int n_heads = 0; + int head_dim = 0; + int rope_head_dim = 0; + + // softmax scaling + float sm_scale = 1.0; + + // used for scheduling + // TODO: remove it after persistent kernel + int max_q_len = 0; + + // private: + // used for performance optimization, don't change it + bool normalized = false; + float sm_scale_log2 = 0.0; + + // used to initialize the params that used for performance optimization + void normalize() { + if (normalized) { + // already normalized + return; + } + sm_scale_log2 = static_cast(sm_scale * M_LOG2E); + + normalized = true; + } +}; + +struct MLAParams : public MLAParamsCommon { + // Q/O: (batch, seq, head, dim): last dimension is contiguous + using Stride = cute::Stride; + // KV: (batch, seq, dim): last dimension is contiguous + using KV_Stride = cute::Stride; + + Stride q_stride; + Stride q_rope_stride; + + KV_Stride kv_stride; + KV_Stride k_rope_stride; + + Stride o_stride; + + // input shapes + int q_len = 0; + int kv_len = 0; +}; + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h new file mode 100644 index 00000000..abe84113 --- /dev/null +++ b/src/kernels/attention/mla_ref.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace llm { +// Multi-head latent attention implementation using pytorch +// reference implementation: +// https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L477 +inline torch::Tensor mla_batch_ref( + torch::Tensor q, // [batch, q_len, n_heads, head_dim] + torch::Tensor kv, // [batch, kv_len, head_dim] + torch::Tensor q_rope, // [batch, q_len, n_heads, rope_head_dim] + torch::Tensor k_rope, // [batch, kv_len, rope_head_dim] + float sm_scale) { + const auto q_len = q.size(-3); + const auto n_heads = q.size(-2); + const auto kv_len = kv.size(-2); + const auto kv_lora_rank = kv.size(-1); + const auto qk_rope_head_dim = q_rope.size(-1); + assert(kv_len >= q_len); + + // use float32 for better precision + auto q_ = q.to(torch::kFloat); + auto kv_ = 
kv.to(torch::kFloat); + auto q_rope_ = q_rope.to(torch::kFloat); + auto k_rope_ = k_rope.to(torch::kFloat); + + // query * key => [batch, q_len, n_heads, kv_len] + auto scores = torch::einsum("bqhr,bkr->bqhk", {q_, kv_}) + + torch::einsum("bqhp,bkp->bqhk", {q_rope_, k_rope_}); + // apply scale + scores *= sm_scale; + + // safe softmax + scores = torch::softmax(scores, /*dim=*/-1); + + // score * value => [batch_size, q_len, n_heads, kv_lora_rank] + return torch::einsum("bqhk,bkr->bqhr", {scores, kv_}).type_as(q); +} + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu new file mode 100644 index 00000000..75be7925 --- /dev/null +++ b/src/kernels/attention/mla_sm80_bench.cu @@ -0,0 +1,109 @@ +#include +#include + +#include +#include + +#include "mla_kernel_sm80.cuh" // IWYU pragma: keep +#include "mla_params.h" +#include "mla_traits_sm80.h" + +using namespace llm; + +#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \ + [&] { \ + if (HEAD_DIM_V <= 64) { \ + constexpr static int HEAD_DIM_NAME = 64; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 128) { \ + constexpr static int HEAD_DIM_NAME = 128; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 256) { \ + constexpr static int HEAD_DIM_NAME = 256; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 384) { \ + constexpr static int HEAD_DIM_NAME = 384; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 512) { \ + constexpr static int HEAD_DIM_NAME = 512; \ + constexpr static int BLK_N = 32; \ + return __VA_ARGS__(); \ + } else { \ + assert(false); \ + } \ + }() + +void mla_bench_sm80(nvbench::state& state) { + // Collect CUPTI metrics + state.collect_cupti_metrics(); + + // Get the parameters + const auto batch_size = state.get_int64("batch_size"); + const auto q_len = state.get_int64("q_len"); + const auto kv_len = state.get_int64("kv_len"); + const auto n_heads = state.get_int64("n_heads"); + const auto head_dim = state.get_int64("head_dim"); + const auto rope_head_dim = state.get_int64("rope_head_dim"); + + const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA); + const auto q = torch::randn({batch_size, q_len, n_heads, head_dim}, options); + const auto kv = torch::randn({batch_size, kv_len, head_dim}, options); + + const auto q_rope = + torch::randn({batch_size, q_len, n_heads, rope_head_dim}, options); + const auto k_rope = + torch::randn({batch_size, kv_len, rope_head_dim}, options); + + auto out = torch::empty_like(q); + + // construct attention params + MLAParams params; + params.q_ptr = q.const_data_ptr(); + params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2)); + params.kv_ptr = kv.const_data_ptr(); + params.kv_stride = make_stride(kv.stride(0), kv.stride(1)); + + params.q_rope_ptr = q_rope.const_data_ptr(); + params.q_rope_stride = + make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2)); + params.k_rope_ptr = k_rope.const_data_ptr(); + params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1)); + + params.o_ptr = out.mutable_data_ptr(); + params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); + + params.batch_size = batch_size; + params.max_q_len = q_len; + params.n_heads = n_heads; + params.q_len = q_len; + params.kv_len = kv_len; + params.head_dim = head_dim; + params.rope_head_dim = 
rope_head_dim;
+  params.sm_scale = 1.0;
+  params.normalize();
+
+  state.exec([&](nvbench::launch& launch) {
+    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+      using Traits = MLATraitsSM80;
+
+      launch_mla_kernel_sm80<Traits>(params, launch.get_stream());
+    });
+  });
+}
+
+NVBENCH_BENCH(mla_bench_sm80)
+    .add_int64_axis("batch_size", {1})
+    .add_int64_axis("q_len", {1024})
+    .add_int64_axis("kv_len", {1024})
+    .add_int64_axis("n_heads", {8})
+    .add_int64_axis("head_dim", {256})
+    .add_int64_axis("rope_head_dim", {64});
diff --git a/src/kernels/attention/mla_tile.h b/src/kernels/attention/mla_tile.h
new file mode 100644
index 00000000..b6b323c1
--- /dev/null
+++ b/src/kernels/attention/mla_tile.h
@@ -0,0 +1,72 @@
+#pragma once
+#include
+#include
+
+#include "gather_tensor.hpp"
+#include "mla_params.h"
+
+namespace llm {
+using namespace cute;
+
+template <typename Params>
+struct MLATile {
+  static_assert(cute::dependent_false<Params>, "not implemented");
+};
+
+// MLATile specialization for MLAParams
+template <>
+struct MLATile<MLAParams> {
+  // NOLINTNEXTLINE
+  const MLAParams& params_;
+
+  CUTE_HOST_DEVICE MLATile(const MLAParams& params) : params_(params) {}
+
+  // return the query/output tile: (q_packed_len, head_dim)
+  // return q_rope tile: (q_packed_len, qk_rope_head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx) const {
+    // (batch, seq, head, dim)
+    const auto q_packed_len = params_.q_len * params_.n_heads;
+    const auto q_offset = batch_idx * get<0>(params_.q_stride);
+    auto q =
+        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+                    make_shape(q_packed_len, params_.head_dim),
+                    make_stride(get<2>(params_.q_stride), _1{}));
+
+    // (batch, seq, head, rope_head_dim)
+    const auto q_rope_offset = batch_idx * get<0>(params_.q_rope_stride);
+    auto q_rope = make_tensor(
+        make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset),
+        make_shape(q_packed_len, params_.rope_head_dim),
+        make_stride(get<2>(params_.q_rope_stride), _1{}));
+
+    // (batch, seq, head, dim)
+    const auto o_offset = batch_idx * get<0>(params_.o_stride);
+    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+                         make_shape(q_packed_len, params_.head_dim),
+                         make_stride(get<2>(params_.o_stride), _1{}));
+    return make_tuple(q, q_rope, o);
+  }
+
+  // return the key/value tile: (kv_len, head_dim)
+  template <typename Element>
+  CUTE_HOST_DEVICE auto get_kv_tile(int batch_idx) const {
+    // (batch, seq, dim)
+    const auto kv_offset = batch_idx * get<0>(params_.kv_stride);
+    // kv[batch_idx, :, :]
+    auto kv =
+        make_tensor(make_gmem_ptr((const Element*)params_.kv_ptr + kv_offset),
+                    make_shape(params_.kv_len, params_.head_dim),
+                    make_stride(get<1>(params_.kv_stride), _1{}));
+
+    // (batch, seq, rope_head_dim)
+    const auto k_rope_offset = batch_idx * get<0>(params_.k_rope_stride);
+    auto k_rope = make_tensor(
+        make_gmem_ptr((const Element*)params_.k_rope_ptr + k_rope_offset),
+        make_shape(params_.kv_len, params_.rope_head_dim),
+        make_stride(get<1>(params_.k_rope_stride), _1{}));
+    return make_tuple(kv, k_rope);
+  }
+};
+
+}  // namespace llm
\ No newline at end of file
diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h
new file mode 100644
index 00000000..31d2a776
--- /dev/null
+++ b/src/kernels/attention/mla_traits_sm80.h
@@ -0,0 +1,159 @@
+#pragma once
+#include
+#include
+
+#include "cute_extensions.cuh"
+
+namespace llm {
+using namespace cute;
+
+namespace detail {
+
+// Convert fragment layout for different purposes
+// Only works for TiledMMA (64x16x16) with SM80_16x8x16_F32F16F16F32_TN
+struct LayoutConvertor {
+  // Convert fragment layout to rowcol layout for iterating
+  // (MMA=4, MMA_M, MMA_N) => ((2, MMA_M), (2, MMA_N))
+  template <typename LayoutC>
+  CUTE_HOST_DEVICE static constexpr auto to_rowcol(const LayoutC& layout) {
+    auto l = logical_divide(layout, Shape<_2>{});
+    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
+                       make_layout(get<0, 0>(l), get<2>(l)));
+  }
+
+  // Convert fragment layout from gemm-I C to gemm-II A
+  // (MMA_C=4,MMA_M,MMA_N) => (MMA_A=(4, 2), MMA_M, MMA_N/2)
+  template <typename LayoutC>
+  CUTE_HOST_DEVICE static constexpr auto to_mma_a(const LayoutC& layout) {
+    auto l =
+        logical_divide(layout.layout(), Shape<Underscore, Underscore, _2>{});
+    return make_layout(
+        make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+  }
+};
+
+}  // namespace detail
+
+template <typename DTYPE,
+          int HEAD_DIM,
+          int ROPE_HEAD_DIM,
+          int BLK_M,
+          int BLK_N,
+          int BLK_K>
+struct MLATraitsSM80 {
+  // helpful aliases
+  static constexpr int kHeadDim = HEAD_DIM;
+  static constexpr int kRopeHeadDim = ROPE_HEAD_DIM;
+  static constexpr int kBlockM = BLK_M;
+  static constexpr int kBlockN = BLK_N;
+  static constexpr int kBlockK = BLK_K;
+  static constexpr int kRowsPerMMA = 2;
+
+  static_assert(kHeadDim % kBlockK == 0);
+  static_assert(kRopeHeadDim % kBlockK == 0);
+
+  using DType = DTYPE;
+  using _BLK_M = Int<BLK_M>;
+  using _BLK_N = Int<BLK_N>;
+  using _BLK_K = Int<BLK_K>;
+  using _HEAD_DIM = Int<HEAD_DIM>;
+  using _ROPE_HEAD_DIM = Int<ROPE_HEAD_DIM>;
+
+  // ******* Mainloop *******
+  // TiledMMA (64x16x16) for gemm-I and gemm-II
+  // choose MMA_Atom based on Element type
+  using MMA_Atom_ =
+      std::conditional_t<std::is_same_v<DTYPE, cute::half_t>,
+                         MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+                         MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
+  using TiledMma = TiledMMA<MMA_Atom_,
+                            Layout<Shape<_4, _1, _1>>,  // warp layout 4x1x1
+                            Tile<_64, _16, _16>>;       // Prom Shape 64x16x16
+
+  // Layout convertor for TiledMMA (64x16x16)
+  using LayoutConvertor = detail::LayoutConvertor;
+
+  // SMEM layout for QKV
+  // Atom layout: (8, BLK_K):(BLK_K, 1) k-major
+  using SmemLayoutAtom =
+      decltype(composition(Swizzle<3, 3, 3>{},
+                           Layout<Shape<_8, _BLK_K>, Stride<_BLK_K, _1>>{}));
+
+  // Q smem: (BLK_M, HEAD_DIM)
+  using SmemLayoutQ =
+      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{}));
+
+  using SmemLayoutQRope =
+      decltype(tile_to_shape(SmemLayoutAtom{},
+                             Shape<_BLK_M, _ROPE_HEAD_DIM>{}));
+
+  // KV smem: (BLK_N, HEAD_DIM)
+  using SmemLayoutKV =
+      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{}));
+
+  using SmemLayoutKRope =
+      decltype(tile_to_shape(SmemLayoutAtom{},
+                             Shape<_BLK_N, _ROPE_HEAD_DIM>{}));
+
+  // V^T smem: (HEAD_DIM, BLK_N)
+  using SmemLayoutVt = decltype(permute<1, 0>(SmemLayoutKV{}));
+
+  // Thr layout for gmem copy
+  using GmemCopyThrLayout =
+      std::conditional_t<BLK_K == 32,
+                         Layout<Shape<_32, _4>, Stride<_4, _1>>,
+                         Layout<Shape<_16, _8>, Stride<_8, _1>>>;
+
+  // Tiled copy for QKV
+  // g2s tiled copy for q
+  using GmemTiledCopyQ = decltype(make_tiled_copy(
+      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>{},
+      GmemCopyThrLayout{},      // Thr layout: (_16,_8)/(_32, _4)
+      Layout<Shape<_1, _8>>{}   // Val layout: 8 vals per read
+      ));
+
+  // g2s tiled copy for kv
+  using GmemTiledCopyKV = GmemTiledCopyQ;
+
+  // s2r tiled copy for gemm-I
+  using SmemTiledCopyQ =
+      decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
+                                 TiledMma{}));
+  using SmemTiledCopyK =
+      decltype(make_tiled_copy_B(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
+                                 TiledMma{}));
+
+  // s2r tiled copy for gemm-II
+  using SmemTiledCopyVt =
+      decltype(make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, DType>{},
+                                 TiledMma{}));
+
+  // ******* Epilogue *******
+
+  // O smem: (BLK_M, HEAD_DIM):(K, 1), k-major
+  using SmemLayoutO =
+      decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{}));
+
+  // use 128-bit vectorizing copy
+  using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
+
+  // s2g tiled copy for O
+  using
GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // r2s tiled copy for O + using SmemTiledCopyO = + decltype(make_tiled_copy_C(Copy_Atom{}, + TiledMma{})); + + // constexpr values for kernel launch + static constexpr size_t kSmemSize = + sizeof(DType) * (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{}) + + cosize(SmemLayoutQRope{}) + cosize(SmemLayoutKRope{})); + + static constexpr size_t kThreadNum = size(TiledMma{}); +}; + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp new file mode 100644 index 00000000..44f75614 --- /dev/null +++ b/src/kernels/attention/mla_traits_test.cpp @@ -0,0 +1,66 @@ +#include + +#include + +#include "cute_extensions.cuh" +#include "gather_tensor.hpp" +#include "mla_traits_sm80.h" + +namespace llm { + +using namespace cute; + +template +void test_mla_traits() { + // type alias + using TiledMma = typename Traits::TiledMma; + using Layout = typename Traits::LayoutConvertor; + + using SmemLayoutQ = typename Traits::SmemLayoutQ; + using SmemLayoutKV = typename Traits::SmemLayoutKV; + using SmemLayoutQRope = typename Traits::SmemLayoutQRope; + using SmemLayoutKRope = typename Traits::SmemLayoutKRope; + using SmemLayoutVt = typename Traits::SmemLayoutVt; + using SmemLayoutO = typename Traits::SmemLayoutO; + + using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; + using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; + using GmemTiledCopyO = typename Traits::GmemTiledCopyO; + + using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ; + using SmemTiledCopyK = typename Traits::SmemTiledCopyK; + using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; + using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + + // test layout conversation + Tensor sQ = make_tensor(counting_iterator(0), SmemLayoutQ{}); + Tensor sKV = make_tensor(counting_iterator(0), SmemLayoutKV{}); + Tensor sVt = make_tensor(sKV.data(), SmemLayoutVt{}); + + Tensor sQ_rope = make_tensor(counting_iterator(0), SmemLayoutQRope{}); + Tensor sKV_rope = make_tensor(counting_iterator(0), SmemLayoutKRope{}); + + // print("sQ:"); print(sQ);print("\n"); + // print("sKV:"); print(sKV);print("\n"); + + // print("sQ_rope:"); print(sQ_rope);print("\n"); + // print("sKV_rope:"); print(sKV_rope);print("\n"); + + // print("sVt:"); print(sVt);print("\n"); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(0); + auto tOrVt = thr_mma.partition_fragment_B(sVt); + // TODO: add tests for layout conformance +} + +TEST(MLATraitsTest, TraitsSM80) { + test_mla_traits>(); +} + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/tools/CMakeLists.txt b/src/kernels/attention/tools/CMakeLists.txt index d6b843a6..64032d65 100644 --- a/src/kernels/attention/tools/CMakeLists.txt +++ b/src/kernels/attention/tools/CMakeLists.txt @@ -6,6 +6,7 @@ cc_binary( SRCS mha_traits_viewer.cpp DEPS + :common cutlass absl::strings absl::str_format diff --git a/src/kernels/attention/tools/mha_traits_viewer.cpp b/src/kernels/attention/tools/mha_traits_viewer.cpp index e5aec6ac..c42fe1e0 100644 --- a/src/kernels/attention/tools/mha_traits_viewer.cpp +++ b/src/kernels/attention/tools/mha_traits_viewer.cpp @@ -2,6 +2,7 @@ #include #include "../mha_traits_sm80.h" +#include "common/pretty_print.h" #include "print_svg.hpp" using namespace cute; @@ -33,6 +34,10 @@ void print_attn_traits() { using 
SmemTiledCopyK = typename Traits::SmemTiledCopyK; using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + // print dynamic smem size + print("Dynamic Smem Size: "); + print(readable_size(Traits::kSmemSize).c_str()); + print("\n"); // print tiled mma print("TiledMma: \n");