kernel: added q and kv oob handling for MLA kernel (#413)
guocuimi authored Feb 27, 2025
1 parent e855f1f commit 8448c0b
Showing 8 changed files with 312 additions and 121 deletions.
2 changes: 2 additions & 0 deletions docker/Dockerfile.devel
@@ -54,6 +54,8 @@ ARG CUDA_VERSION=12.1
COPY ./common/install_cuda.sh install_cuda.sh
RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh
ENV DESIRED_CUDA=${CUDA_VERSION}
ENV CUDA_HOME=/usr/local/cuda
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
RUN nvcc --version

78 changes: 74 additions & 4 deletions src/kernels/attention/generate_instantiation_cu.py
@@ -5,7 +5,7 @@
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator
from typing import Iterator, Any

# map from python to c++ types
DTYPE_MAP = {
@@ -39,6 +39,22 @@
}} // namespace llm
"""

MLA_KERNEL_TEMPLATE = """
#include "mla_kernel_sm80.cuh" // IWYU pragma: export
#include "mla_params.h" // IWYU pragma: export
#include "mla_traits_sm80.h" // IWYU pragma: export
namespace llm {{
using Traits = MLATraitsSM80<{DTYPE}, {HEAD_DIM}, {ROPE_HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}, {STAGES}>;
using Params = MLAParams;
template void launch_mla_kernel_sm80<Traits,
Params>(const Params& params,
cudaStream_t stream);
}} // namespace llm
"""


@dataclass
class MHAKernel:
@@ -73,10 +89,38 @@ def filename(self) -> str:
def to_str(val: bool) -> str:
return "1" if val else "0"

return f"mha_{self.dtype}_hd{self.head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}.cu"
return f"mha_{self.dtype}_hd{self.head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"

@dataclass
class MLAKernel:
dtype: str
head_dim: int
rope_head_dim: int
blk_m: int
blk_n: int
blk_k: int
stages: int

@property
def template(self) -> str:
assert self.head_dim % self.blk_k == 0

return MLA_KERNEL_TEMPLATE.format(
DTYPE=DTYPE_MAP[self.dtype],
HEAD_DIM=self.head_dim,
ROPE_HEAD_DIM=self.rope_head_dim,
BLK_M=self.blk_m,
BLK_N=self.blk_n,
BLK_K=self.blk_k,
STAGES=self.stages,
)

@property
def filename(self) -> str:
return f"mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_s{self.stages}_sm80.cu"


def gen_all_kernels() -> Iterator[MHAKernel]:
def gen_mha_kernels() -> Iterator[MHAKernel]:
# mha kernel instantiations
for (
dtype,
@@ -114,12 +158,38 @@ def gen_all_kernels() -> Iterator[MHAKernel]:
local=local,
)

def gen_mla_kernels() -> Iterator[MLAKernel]:
# TODO: choose BLK_M, BLK_N, BLK_K, STAGES based on compute capability
# mla kernel instantiations
for (
dtype,
head_dim,
rope_head_dim,
(blk_m, blk_n, blk_k, stages)
) in itertools.product(
["fp16", "bf16"], # dtype
[512], # head_dim
[64], # rope_head_dim
[(64, 16, 128, 1)], # blk_m, blk_n, blk_k, stages
):
yield MLAKernel(
dtype=dtype,
head_dim=head_dim,
rope_head_dim=rope_head_dim,
blk_m=blk_m,
blk_n=blk_n,
blk_k=blk_k,
stages=stages,
)

if __name__ == "__main__":
output_dir = Path.cwd() / "generated"
shutil.rmtree(output_dir, ignore_errors=True)
output_dir.mkdir(parents=True, exist_ok=True)

# written to several files to speed up compilation
for kernel in gen_all_kernels():
for kernel in gen_mha_kernels():
(output_dir / kernel.filename).write_text(kernel.template)

for kernel in gen_mla_kernels():
(output_dir / kernel.filename).write_text(kernel.template)
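
For reference, the single configuration yielded by gen_mla_kernels (head_dim 512, rope_head_dim 64, blocks 64/16/128, 1 stage) produces files named mla_fp16_hd512_rhd64_m64_n16_k128_s1_sm80.cu and mla_bf16_hd512_rhd64_m64_n16_k128_s1_sm80.cu. A sketch of the fp16 file as rendered from MLA_KERNEL_TEMPLATE, assuming DTYPE_MAP maps "fp16" to cute::half_t (the map's entries sit outside this diff):

// mla_fp16_hd512_rhd64_m64_n16_k128_s1_sm80.cu -- illustrative rendering, not part of the commit
#include "mla_kernel_sm80.cuh"  // IWYU pragma: export
#include "mla_params.h"         // IWYU pragma: export
#include "mla_traits_sm80.h"    // IWYU pragma: export

namespace llm {

using Traits = MLATraitsSM80<cute::half_t, 512, 64, 64, 16, 128, 1>;
using Params = MLAParams;

template void launch_mla_kernel_sm80<Traits, Params>(const Params& params,
                                                     cudaStream_t stream);

}  // namespace llm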
14 changes: 7 additions & 7 deletions src/kernels/attention/mask.h
@@ -5,7 +5,7 @@
namespace llm {
using namespace cute;

template <int BLK_M, int BLK_N, int ROWS_PER_THR, bool ALIBI, bool LOCAL>
template <int ROWS_PER_THR, bool ALIBI, bool LOCAL>
struct Mask {
// Fragment type for alibi slopes
using FragmentT = decltype(make_tensor<float>(Int<ROWS_PER_THR>{}));
@@ -31,15 +31,15 @@ struct Mask {
// cS_mn: ((2, MMA_M), (2, MMA_N))
template <typename IdentityS>
CUTE_HOST_DEVICE void init_alibi(IdentityS& cS_mn,
int m_block_idx,
int m_base_idx,
int kv_head_idx,
float sm_scale,
const float* alibi_slops_ptr) {
// copy alibi slopes to registers
CUTE_UNROLL
for (int i = 0; i < size<0>(cS_mn); ++i) {
const auto [m, n] = cS_mn(i, _0{});
const int q_packed_idx = m_block_idx * BLK_M + m;
const int q_packed_idx = m_base_idx + m;
const int offset = q_packed_idx % group_size_;
const int head_idx = kv_head_idx * group_size_ + offset;
alibi_slopes_(i) = alibi_slops_ptr[head_idx] / sm_scale;
@@ -50,16 +50,16 @@
template <bool OOB_MASK = true, typename FragmentS, typename IdentityS>
CUTE_HOST_DEVICE void apply(FragmentS& rS_mn,
IdentityS& cS_mn,
int m_block_idx,
int n_block_idx) const {
int m_base_idx,
int n_base_idx) const {
CUTE_UNROLL
for (int i = 0; i < size<0>(rS_mn); ++i) {
const auto alibi_slope = ALIBI ? alibi_slopes_(i) : 0.0f;
CUTE_UNROLL
for (int j = 0; j < size<1>(rS_mn); ++j) {
auto [m, n] = cS_mn(i, j);
const int q_packed_idx = m_block_idx * BLK_M + m;
const int kv_idx = n_block_idx * BLK_N + n;
const int q_packed_idx = m_base_idx + m;
const int kv_idx = n_base_idx + n;

const int q_idx = q_packed_idx / group_size_ + diagonal_offset_;

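With this change, Mask no longer takes BLK_M/BLK_N as template parameters: init_alibi() and apply() receive element-level base offsets, and each caller converts its block index to a base offset itself (see the mha_kernel_sm80.cuh call sites below). A minimal sketch of the index arithmetic, with illustrative block sizes and indices:

// Before: the mask scaled block indices by its own BLK_M/BLK_N template parameters:
//   q_packed_idx = m_block_idx * BLK_M + m;
//   kv_idx       = n_block_idx * BLK_N + n;
// After: the caller passes the base offsets, e.g. with kBlockM = 64, kBlockN = 16,
// m_block_idx = 2 and n_block_idx = 3:
const int m_base_idx = m_block_idx * kBlockM;  // 2 * 64 = 128
const int n_base_idx = n_block_idx * kBlockN;  // 3 * 16 = 48
mask.apply(tSrS_mn, tScS_mn, m_base_idx, n_base_idx);
// inside apply(): q_packed_idx = m_base_idx + m;  kv_idx = n_base_idx + n;
// Decoupling the mask from the tile shape presumably lets other kernels (such as the
// MLA kernel in this commit) reuse it with whatever base offsets they compute.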
63 changes: 47 additions & 16 deletions src/kernels/attention/mha_kernel_sm80.cuh
@@ -7,6 +7,7 @@
#include <cute/tensor.hpp>

#include "cute/config.hpp"
#include "cute/container/array_aligned.hpp"
#include "cute_extensions.cuh"
#include "fast_cast.cuh"
#include "mask.h"
@@ -16,6 +17,31 @@

namespace llm {

template <typename Traits>
struct MHASharedStorage {
using DType = typename Traits::DType;
using SmemLayoutQ = typename Traits::SmemLayoutQ;
using SmemLayoutK = typename Traits::SmemLayoutK;
using SmemLayoutV = typename Traits::SmemLayoutV;
using SmemLayoutVt = typename Traits::SmemLayoutVt;
using SmemLayoutO = typename Traits::SmemLayoutO;

union {
union {
cute::array_aligned<DType, cute::cosize_v<SmemLayoutQ>> q_smem;
struct {
cute::array_aligned<DType, cute::cosize_v<SmemLayoutK>> k_smem;
union {
cute::array_aligned<DType, cute::cosize_v<SmemLayoutV>> v_smem;
cute::array_aligned<DType, cute::cosize_v<SmemLayoutVt>> vt_smem;
};
};
};

cute::array_aligned<DType, cute::cosize_v<SmemLayoutO>> o_smem;
};
};

template <typename Traits,
typename Params,
bool EVEN_K,
@@ -46,6 +72,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
using SmemLayoutV = typename Traits::SmemLayoutV;
using SmemLayoutVt = typename Traits::SmemLayoutVt;
using SmemLayoutO = typename Traits::SmemLayoutO;
using SharedStorage = MHASharedStorage<Traits>;

using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ;
using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV;
using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
@@ -98,19 +126,20 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(

// Smem
extern __shared__ char smem[];
DType* q_smem = (DType*)smem;
DType* k_smem = (DType*)smem;
DType* v_smem = k_smem + cosize(SmemLayoutK{});
auto& ss = *reinterpret_cast<SharedStorage*>(smem);

// (BLK_M, HEAD_DIM), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
Tensor sQ = make_tensor(make_smem_ptr(ss.q_smem.data()), SmemLayoutQ{});
// (BLK_N, HEAD_DIM), k-major
Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{});
Tensor sK = make_tensor(make_smem_ptr(ss.k_smem.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(ss.v_smem.data()), SmemLayoutV{});

// Tensor for V^t; used in GEMM-II.
// (HEAD_DIM, BLK_N), m-major
Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{});
Tensor sVt = make_tensor(make_smem_ptr(ss.vt_smem.data()), SmemLayoutVt{});

// (BLK_M, HEAD_DIM)
Tensor sO = make_tensor(make_smem_ptr(ss.o_smem.data()), SmemLayoutO{});

// Tiled Copy
// g2s tiled copy for qkv
@@ -249,9 +278,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
auto tOrO = make_tensor_like<DType>(tOrAccO);
fast_cast(tOrAccO, tOrO);

// 2. copy output from reg to smem (reuse sQ)
auto sO = make_tensor(sQ.data(), SmemLayoutO{});

// 2. copy output from reg to smem
SmemTiledCopyO smem_tiled_copy_O;
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
auto taccOrO = smem_thr_copy_O.retile_S(tOrO);
@@ -338,13 +365,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerThr>;
using Mask = Mask<kBlockM, kBlockN, kRowsPerThr, ALIBI, LOCAL>;
using Mask = Mask<kRowsPerThr, ALIBI, LOCAL>;

Softmax softmax(sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);
if constexpr (ALIBI) {
mask.init_alibi(
tScS_mn, m_block_idx, kv_head_idx, sm_scale, params.alibi_slopes_ptr);
mask.init_alibi(tScS_mn,
m_block_idx * kBlockM,
kv_head_idx,
sm_scale,
params.alibi_slopes_ptr);
}

CUTE_NO_UNROLL
@@ -376,10 +406,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
}

if (i < n_oob_mask) {
mask.apply(tSrS_mn, tScS_mn, m_block_idx, n_block_idx);
mask.apply(
tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
} else {
mask.apply</*OOB_MASK=*/false>(
tSrS_mn, tScS_mn, m_block_idx, n_block_idx);
tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
}
softmax.rescale(tSrS_mn, tOrO_mn);

@@ -413,7 +444,7 @@ void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
const auto n_kv_heads = params.n_kv_heads;
const auto max_q_packed_len = params.max_q_len * params.group_size;

const auto smem_size = Traits::kSmemSize;
const auto smem_size = sizeof(MHASharedStorage<Traits>);
auto mha_kernel =
mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
cudaFuncSetAttribute(
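
With the nested unions in MHASharedStorage, the Q tile shares storage with the K tile and the V/Vt tiles, and that whole region in turn shares storage with the output tile, so the kernel's dynamic shared-memory footprint is the maximum over the aliased members rather than their sum, and the launcher can size it with a plain sizeof instead of the hand-computed Traits::kSmemSize, as the last hunk above shows. A rough sketch of that launch path; the attribute constant and the launch line are assumptions filling in the truncated lines, not code from this commit:

// Sketch: sizing dynamic shared memory from the union-based shared storage.
using SharedStorage = MHASharedStorage<Traits>;
const auto smem_size = sizeof(SharedStorage);  // max of the aliased Q, K+V/Vt, and O footprints

auto kernel = mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
// Opt in to more than the default 48 KB of dynamic shared memory where required.
cudaFuncSetAttribute(
    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
// The kernel is then launched with smem_size as the dynamic shared-memory argument:
//   kernel<<<grid, block, smem_size, stream>>>(params);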
