diff --git a/docker/Dockerfile.devel b/docker/Dockerfile.devel
index 6dd9f4cb..d139f677 100644
--- a/docker/Dockerfile.devel
+++ b/docker/Dockerfile.devel
@@ -54,6 +54,8 @@ ARG CUDA_VERSION=12.1
 COPY ./common/install_cuda.sh install_cuda.sh
 RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh
 ENV DESIRED_CUDA=${CUDA_VERSION}
+ENV CUDA_HOME=/usr/local/cuda
+ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
 RUN nvcc --version
 
diff --git a/src/kernels/attention/mask.h b/src/kernels/attention/mask.h
index 46864146..44152025 100644
--- a/src/kernels/attention/mask.h
+++ b/src/kernels/attention/mask.h
@@ -5,7 +5,7 @@ namespace llm {
 
 using namespace cute;
 
-template
+template
 struct Mask {
   // Fragment type for alibi slopes
   using FragmentT = decltype(make_tensor(Int{}));
@@ -31,7 +31,7 @@ struct Mask {
   // cS_mn: ((2, MMA_M), (2, MMA_N))
   template
   CUTE_HOST_DEVICE void init_alibi(IdentityS& cS_mn,
-                                   int m_block_idx,
+                                   int m_base_idx,
                                    int kv_head_idx,
                                    float sm_scale,
                                    const float* alibi_slops_ptr) {
@@ -39,7 +39,7 @@ struct Mask {
     CUTE_UNROLL
     for (int i = 0; i < size<0>(cS_mn); ++i) {
       const auto [m, n] = cS_mn(i, _0{});
-      const int q_packed_idx = m_block_idx * BLK_M + m;
+      const int q_packed_idx = m_base_idx + m;
       const int offset = q_packed_idx % group_size_;
       const int head_idx = kv_head_idx * group_size_ + offset;
       alibi_slopes_(i) = alibi_slops_ptr[head_idx] / sm_scale;
@@ -50,16 +50,16 @@ struct Mask {
   template
   CUTE_HOST_DEVICE void apply(FragmentS& rS_mn,
                               IdentityS& cS_mn,
-                              int m_block_idx,
-                              int n_block_idx) const {
+                              int m_base_idx,
+                              int n_base_idx) const {
     CUTE_UNROLL
     for (int i = 0; i < size<0>(rS_mn); ++i) {
      const auto alibi_slope = ALIBI ? alibi_slopes_(i) : 0.0f;
      CUTE_UNROLL
      for (int j = 0; j < size<1>(rS_mn); ++j) {
        auto [m, n] = cS_mn(i, j);
-        const int q_packed_idx = m_block_idx * BLK_M + m;
-        const int kv_idx = n_block_idx * BLK_N + n;
+        const int q_packed_idx = m_base_idx + m;
+        const int kv_idx = n_base_idx + n;
         const int q_idx = q_packed_idx / group_size_ + diagonal_offset_;
diff --git a/src/kernels/attention/mha_kernel_sm80.cuh b/src/kernels/attention/mha_kernel_sm80.cuh
index 4c06abfc..b60d282a 100644
--- a/src/kernels/attention/mha_kernel_sm80.cuh
+++ b/src/kernels/attention/mha_kernel_sm80.cuh
@@ -338,13 +338,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
   constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
   using Softmax = OnlineSoftmax;
-  using Mask = Mask;
+  using Mask = Mask;
 
   Softmax softmax(sm_scale_log2);
   Mask mask(q_len, kv_len, group_size, sliding_window);
 
   if constexpr (ALIBI) {
-    mask.init_alibi(
-        tScS_mn, m_block_idx, kv_head_idx, sm_scale, params.alibi_slopes_ptr);
+    mask.init_alibi(tScS_mn,
+                    m_block_idx * kBlockM,
+                    kv_head_idx,
+                    sm_scale,
+                    params.alibi_slopes_ptr);
   }
 
   CUTE_NO_UNROLL
@@ -376,10 +379,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
     }
 
     if (i < n_oob_mask) {
-      mask.apply(tSrS_mn, tScS_mn, m_block_idx, n_block_idx);
+      mask.apply(
+          tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
     } else {
       mask.apply(
-          tSrS_mn, tScS_mn, m_block_idx, n_block_idx);
+          tSrS_mn, tScS_mn, m_block_idx * kBlockM, n_block_idx * kBlockN);
     }
 
     softmax.rescale(tSrS_mn, tOrO_mn);
diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh
index 42a675e9..5d92f15f 100644
--- a/src/kernels/attention/mla_kernel_sm80.cuh
+++ b/src/kernels/attention/mla_kernel_sm80.cuh
@@ -16,12 +16,7 @@ namespace llm {
 
-template
+template
 __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     __grid_constant__ const Params params) {
   using namespace cute;
@@ -89,9 +84,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // K_ROPE: (kv_len, ROPE_HEAD_DIM)
   auto [KV, K_ROPE] = tile.template get_kv_tile(batch_idx);
 
-  const int q_len = size<0>(Q) / group_size;
+  const int q_packed_len = size<0>(Q);
+  const int q_len = q_packed_len / group_size;
   const int kv_len = size<0>(KV);
-  const int sliding_window = kv_len;
 
   if (m_block_idx * kBlockM >= size<0>(Q)) {
     // m out of bound, return
@@ -180,42 +175,84 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // g2s tiled copy for q
   GmemTiledCopyQ gmem_tiled_copy_Q;
   auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
+
+  // coordinate tensor for oob handling
+  // (BLK_M, BLK_K) -> (blk_m, blk_k)
+  Tensor cQ = make_identity_tensor(Shape<_BLK_M, _BLK_K>{});
+  Tensor tCcQ = gmem_thr_copy_Q.partition_S(cQ(_, _));
+  auto max_coord_Q = make_coord(q_packed_len - m_block_idx * kBlockM, kBlockK);
+
   auto produce_q = [&](int step) {
     // gQ/sQ: (BLK_M, BLK_K, STEPS)
     auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step));
     auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step));
-    cute::copy(gmem_tiled_copy_Q, tCgQ, tCsQ);
+    safe_copy(
+        gmem_tiled_copy_Q, tCgQ, tCsQ, tCcQ, max_coord_Q);
   };
 
   // g2s tiled copy for q_rope
   GmemTiledCopyQRope gmem_tiled_copy_Q_rope;
   auto gmem_thr_copy_Q_rope = gmem_tiled_copy_Q_rope.get_slice(tidx);
+
+  // (BLK_M, ROPE_HEAD_DIM) -> (blk_m, rope_head_dim)
+  Tensor cQ_rope = make_identity_tensor(Shape<_BLK_M, _ROPE_HEAD_DIM>{});
+  Tensor tCcQ_rope = gmem_thr_copy_Q_rope.partition_S(cQ_rope);
+
   auto produce_q_rope = [&]() {
     auto tCgQ_rope = gmem_thr_copy_Q_rope.partition_S(gQ_rope);
     auto tCsQ_rope = gmem_thr_copy_Q_rope.partition_D(sQ_rope);
-    cute::copy(gmem_tiled_copy_Q_rope, tCgQ_rope, tCsQ_rope);
+    auto max_coord =
+        make_coord(q_packed_len - m_block_idx * kBlockM, kRopeHeadDim);
+    safe_copy(
+        gmem_tiled_copy_Q_rope, tCgQ_rope, tCsQ_rope, tCcQ_rope, max_coord);
   };
 
   // g2s tiled copy for kv
   GmemTiledCopyKV gmem_tiled_copy_KV;
   auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);
+
+  // (BLK_N, BLK_K, STEPS) -> (blk_n, head_dim)
+  Tensor cKV = make_identity_tensor(Shape<_BLK_N, _BLK_K>{});
+  Tensor tCcKV = gmem_thr_copy_KV.partition_S(cKV);
+
   auto produce_kv = [&](int ni, int step, int stage) {
     // gKV: (BLK_N, BLK_K, n, STEPS)
     // sK: (BLK_N, BLK_K, STEPS, STAGES)
     auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step));
     auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, step, stage));
-    cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV);
+    auto max_coord = make_coord(kv_len - ni * kBlockN, kBlockK);
+    safe_copy(
+        gmem_tiled_copy_KV, tCgKV, tCsKV, tCcKV, max_coord);
   };
 
   // g2s tiled copy for k_rope
   GmemTiledCopyKRope gmem_tiled_copy_K_rope;
   auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx);
+
+  // (BLK_N, ROPE_HEAD_DIM) -> (blk_n, rope_head_dim)
+  Tensor cK_rope = make_identity_tensor(Shape<_BLK_N, _ROPE_HEAD_DIM>{});
+  Tensor tKcK_rope = gmem_thr_copy_K_rope.partition_S(cK_rope);
+
   auto produce_k_rope = [&](int ni, int stage) {
     // gK_rope: (BLK_N, ROPE_HEAD_DIM, n)
     // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
     auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni));
-    Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
-    cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
+    auto tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, kRopeHeadDim);
+    safe_copy(
+        gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope, tKcK_rope, max_coord);
   };
 
   // GEMM-I: S = Q@K.T
@@ -382,9 +419,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx);
 
-    auto tCsO = gmem_thr_copy_O.partition_S(sO);
-    auto tCgO = gmem_thr_copy_O.partition_D(gO);
-    cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
+    // (BLK_M, BLK_K) -> (blk_m, blk_k)
+    auto cO = make_identity_tensor(Shape<_BLK_M, _BLK_K>{});
+    auto tCcO = gmem_thr_copy_Q.partition_S(cO);
+    auto max_coord_O =
+        make_coord(q_packed_len - m_block_idx * kBlockM, kBlockK);
+
+    CUTE_UNROLL
+    for (int step = 0; step < kSteps; ++step) {
+      auto tCsO = gmem_thr_copy_O.partition_S(sO(_, _, step));
+      auto tCgO = gmem_thr_copy_O.partition_D(gO(_, _, step));
+
+      safe_copy(
+          gmem_tiled_copy_O, tCsO, tCgO, tCcO, max_coord_O);
+    }
   };
 
   // output accumulator: (MMA,MMA_M,MMA_K,STEPS)
@@ -441,13 +492,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
   using Softmax = OnlineSoftmax;
-  using Mask = Mask;
+  using Mask = Mask;
 
   Softmax softmax(params.sm_scale_log2);
-  Mask mask(q_len, kv_len, group_size, sliding_window);
+  Mask mask(q_len, kv_len, group_size, /*sliding_window=*/kv_len);
 
   int stage = 0;
-
   CUTE_NO_UNROLL
   for (int ni = n_block_min; ni < n_block_max; ++ni) {
     clear(tSrS);
@@ -470,7 +520,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     }
 
     // apply mask + softmax
-    mask.apply(tSrS_mn, tScS_mn, m_block_idx, ni);
+    mask.apply(tSrS_mn, tScS_mn, m_block_idx * kBlockM, ni * kBlockN);
     softmax.rescale(tSrS_mn, tOrO_mn, reduce_rowmax);
 
     // save tSrS from rmem to smem
@@ -520,12 +570,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   epilogue(tOrO);
 }
 
-template
+template
 void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) {
   const auto batch_size = params.batch_size;
   const auto max_q_packed_len = params.max_q_len * params.n_heads;
@@ -533,8 +578,7 @@ void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) {
   const auto smem_size = Traits::kSmemSize;
   // print("smem_size: %d\n", smem_size);
 
-  auto mla_kernel =
-      mla_kernel_sm80;
+  auto mla_kernel = mla_kernel_sm80;
   C10_CUDA_CHECK(cudaFuncSetAttribute(
       mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
   // TODO: support persistent kernels
diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu
index b315ac24..b0021232 100644
--- a/src/kernels/attention/mla_kernel_sm80_test.cu
+++ b/src/kernels/attention/mla_kernel_sm80_test.cu
@@ -15,21 +15,21 @@ namespace llm {
 
 #define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
   [&] {                                                    \
-    if (HEAD_DIM_V <= 128) {                               \
+    if (HEAD_DIM_V == 128) {                               \
       constexpr static int HEAD_DIM_NAME = 128;            \
       constexpr static int BLK_M = 64;                     \
       constexpr static int BLK_N = 64;                     \
       constexpr static int BLK_K = 128;                    \
       constexpr static int STAGES = 2;                     \
       return __VA_ARGS__();                                \
-    } else if (HEAD_DIM_V <= 256) {                        \
+    } else if (HEAD_DIM_V == 256) {                        \
       constexpr static int HEAD_DIM_NAME = 256;            \
       constexpr static int BLK_M = 64;                     \
       constexpr static int BLK_N = 32;                     \
       constexpr static int BLK_K = 128;                    \
       constexpr static int STAGES = 2;                     \
       return __VA_ARGS__();                                \
-    } else if (HEAD_DIM_V <= 512) {                        \
+    } else if (HEAD_DIM_V == 512) {                        \
       constexpr static int HEAD_DIM_NAME = 512;            \
       constexpr static int BLK_M = 64;                     \
       constexpr static int BLK_N = 16;                     \
@@ -43,7 +43,7 @@ namespace llm {
 
 #define DISPATCH_ROPE_HEAD_DIM_(ROPE_HEAD_DIM_V, ROPE_HEAD_DIM_NAME, ...) \
   [&] {                                                                   \
-    if (ROPE_HEAD_DIM_V <= 64) {                                          \
+    if (ROPE_HEAD_DIM_V == 64) {                                          \
       constexpr static int ROPE_HEAD_DIM_NAME = 64;                       \
       return __VA_ARGS__();                                               \
     } else {                                                              \
@@ -159,13 +159,13 @@ TEST_P(MLAKernelTest, MLA) {
 INSTANTIATE_TEST_SUITE_P(
     MLA,
     MLAKernelTest,
-    ::testing::Combine(::testing::Values(torch::kHalf),   // q_dtype
-                       ::testing::Values(1, 2, 4, 10),    // batch_size
-                       ::testing::Values(64),             // q_len
-                       ::testing::Values(64, 128, 1024),  // kv_len
-                       ::testing::Values(1, 8, 128),      // n_heads
-                       ::testing::Values(128, 256, 512),  // head_dim
-                       ::testing::Values(64)              // rope_head_dim
+    ::testing::Combine(::testing::Values(torch::kHalf),    // q_dtype
+                       ::testing::Values(1, 2, 4, 10),     // batch_size
+                       ::testing::Values(1, 62, 125),      // q_len
+                       ::testing::Values(127, 287, 1000),  // kv_len
+                       ::testing::Values(1, 8, 128),       // n_heads
+                       ::testing::Values(128, 256, 512),   // head_dim
+                       ::testing::Values(64)               // rope_head_dim
                        ));
 
 }  // namespace llm
\ No newline at end of file
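
Note on the mask change above (not part of the patch): Mask::init_alibi and Mask::apply now take absolute base offsets (m_block_idx * kBlockM, n_block_idx * kBlockN) instead of block indices, so the mask itself no longer depends on the tile shape. The sketch below is a hypothetical, host-only restatement of that index arithmetic for plain causal plus out-of-bounds masking; MaskSketch and masked_out are illustrative names, and diagonal_offset == kv_len - q_len is an assumption since the Mask constructor is not shown in this diff.

// mask_sketch.cc -- standalone illustration of the base-index convention.
#include <cassert>

struct MaskSketch {
  int q_len;       // number of query tokens
  int kv_len;      // number of key/value tokens
  int group_size;  // query heads per kv head (packed rows per query token)

  // m_base_idx / n_base_idx are the tile origins passed by the kernels,
  // i.e. m_block_idx * kBlockM and n_block_idx * kBlockN; (m, n) is the
  // in-tile coordinate taken from the identity tensor.
  bool masked_out(int m_base_idx, int n_base_idx, int m, int n) const {
    const int q_packed_idx = m_base_idx + m;  // absolute packed query row
    const int kv_idx = n_base_idx + n;        // absolute kv column
    const int diagonal_offset = kv_len - q_len;  // assumed causal alignment
    const int q_idx = q_packed_idx / group_size + diagonal_offset;
    const bool oob = q_packed_idx >= q_len * group_size || kv_idx >= kv_len;
    return oob || kv_idx > q_idx;  // out-of-bounds mask + causal mask
  }
};

int main() {
  MaskSketch mask{/*q_len=*/125, /*kv_len=*/287, /*group_size=*/8};
  // The same score element is addressed identically whether the caller uses
  // base offsets (64, 128) with in-tile (3, 5) or (0, 0) with (67, 133).
  assert(mask.masked_out(64, 128, 3, 5) == mask.masked_out(0, 0, 67, 133));
  return 0;
}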