kernel: use FastDivmod in attention kernels (#417)
guocuimi authored Mar 1, 2025
1 parent 1cc659f commit 8492879
Showing 24 changed files with 315 additions and 539 deletions.
1 change: 1 addition & 0 deletions .clang-tidy
@@ -34,6 +34,7 @@ Checks: >
-readability-isolate-declaration,
-cppcoreguidelines-avoid-magic-numbers,
-cppcoreguidelines-avoid-non-const-global-variables,
-cppcoreguidelines-avoid-const-or-ref-data-members,
-cppcoreguidelines-special-member-functions,
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-prefer-member-initializer,
4 changes: 2 additions & 2 deletions src/kernels/attention/CMakeLists.txt
@@ -7,7 +7,8 @@ cc_library(
NAME
attention.template
HDRS
ptx.cuh
fast_math.h
layout_convertor.h
fast_cast.cuh
online_softmax.cuh
mask.h
@@ -62,7 +63,6 @@ cc_test(
# mha_cpu_test.cpp
mha_traits_test.cpp
mha_kernel_sm80_test.cu
# mha_kernel_sm80_varlen_test.cu
mha_kernel_sm80_pagedkv_test.cu
DEPS
:attention.template
127 changes: 127 additions & 0 deletions src/kernels/attention/fast_math.h
@@ -0,0 +1,127 @@
#pragma once

#include <cuda.h>

#include <cute/config.hpp>

namespace llm {

CUTE_HOST_DEVICE constexpr int clz(int x) {
for (int i = 31; i >= 0; --i) {
if ((1 << i) & x) {
return int(31 - i);
}
}
return int(32);
}

CUTE_HOST_DEVICE constexpr bool is_pow2(int x) { return (x & (x - 1)) == 0; }

CUTE_HOST_DEVICE constexpr int log2(int x) {
int a = int(31 - clz(x));
// add 1 if not a power of 2
if (!is_pow2(x)) {
a += 1;
}
return a;
}

// wrapper of PTX ex2.approx instruction, which computes 2^x
CUTE_HOST_DEVICE float exp2(float x) {
#if defined(__CUDA_ARCH__)
float y;
asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
#else
return std::exp2(x);
#endif
}

// wrapper of PTX rcp.approx instruction, which computes 1/x
CUTE_HOST_DEVICE float rcp(float x) {
#if defined(__CUDA_ARCH__)
float y;
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
#else
return 1.0f / x;
#endif
}

// wrapper of PTX tanh.approx instruction, which computes tanh(x)
CUTE_HOST_DEVICE float tanh(float x) {
#if defined(__CUDA_ARCH__)
float y;
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
#else
return std::tanh(x);
#endif
}

struct FastDivmod {
int32_t div_ = 1;
uint32_t mul_ = 0u;
uint32_t shr_ = 0u;

CUTE_HOST_DEVICE
void reset(int div) {
div_ = div;
if (div_ != 1) {
unsigned int p = 31 + log2(div_);
unsigned m =
unsigned(((1ull << p) + unsigned(div_) - 1) / unsigned(div_));

mul_ = m;
shr_ = p - 32;
}
}

constexpr FastDivmod() = default;

CUTE_HOST_DEVICE
FastDivmod(int div) { reset(div); }

CUTE_HOST_DEVICE
FastDivmod& operator=(int div) {
reset(div);
return *this;
}

CUTE_HOST_DEVICE
void divmod(int src, int& quo, int& rem) const {
quo = div(src);
rem = src - (quo * div_);
}

CUTE_HOST_DEVICE
int div(int src) const {
#if defined(__CUDA_ARCH__)
return (div_ != 1) ? __umulhi(src, mul_) >> shr_ : src;
#else
return src / div_;
#endif
}

CUTE_HOST_DEVICE
int mod(int src) const {
#if defined(__CUDA_ARCH__)
return div_ != 1 ? src - (div(src) * div_) : 0;
#else
return src % div_;
#endif
}

CUTE_HOST_DEVICE
operator int() const { return div_; }
};

// operator overloads for FastDivmod
CUTE_HOST_DEVICE int operator/(int src, const FastDivmod& d) {
return d.div(src);
}
CUTE_HOST_DEVICE int operator%(int src, const FastDivmod& d) {
return d.mod(src);
}

} // namespace llm
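For reference, here is a minimal host-side sketch of how the new FastDivmod might be exercised; it assumes fast_math.h and the CuTe headers are on the include path, and the divisor of 3 is a made-up stand-in for the attention group size (query heads per KV head). On the host, div() and mod() simply fall back to ordinary integer division, while the operator overloads make the type a drop-in replacement for '/' and '%'.

#include <cassert>

#include "fast_math.h"

int main() {
  // Hypothetical group size: 3 query heads share one KV head.
  llm::FastDivmod group_size(3);

  for (int packed_idx = 0; packed_idx < 16; ++packed_idx) {
    int q_idx = 0;
    int offset = 0;
    // Split a packed query index into (query row, head offset within the group).
    group_size.divmod(packed_idx, q_idx, offset);
    assert(q_idx == packed_idx / 3);
    assert(offset == packed_idx % 3);

    // The free operator/ and operator% overloads dispatch to the same fast path.
    assert(packed_idx / group_size == q_idx);
    assert(packed_idx % group_size == offset);
  }
  return 0;
}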
7 changes: 3 additions & 4 deletions src/kernels/attention/generate_instantiation_cu.py
@@ -47,11 +47,10 @@
namespace llm {{
using Traits = MLATraitsSM80<{DTYPE}, {HEAD_DIM}, {ROPE_HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}, {STAGES}>;
using Params = MLAParams;
using Params = MLAPagedKVParams;
template void launch_mla_kernel_sm80<Traits,
Params>(const Params& params,
cudaStream_t stream);
template void launch_mla_kernel_sm80<Traits, Params>(const Params& params,
cudaStream_t stream);
}} // namespace llm
"""

39 changes: 39 additions & 0 deletions src/kernels/attention/layout_convertor.h
@@ -0,0 +1,39 @@
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>

namespace llm {
using namespace cute;

// Convert fragment layout for different purposes
// Only works for TiledMMA (64x16x16) with SM80_16x8x16_F32F16F16F32_TN
struct LayoutConvertor {
// Convert fragment layout to rowcol layout for iterating
// (MMA=4, MMA_M, MMA_N) => ((2, MMA_M), (2, MMA_N))
template <typename LayoutC>
CUTE_HOST_DEVICE static constexpr auto to_mn(const LayoutC& layout) {
auto l = logical_divide(layout, Shape<_2>{});
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)));
}

// (MMA=4, MMA_M, MMA_N, STEPS) => ((2, MMA_M), (2, MMA_N), STEPS)
template <typename LayoutC>
CUTE_HOST_DEVICE static constexpr auto to_mns(const LayoutC& layout) {
auto l = logical_divide(layout, Shape<_2>{});
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)),
get<3>(l));
}

// Convert fragment layout from gemm-I C to gemm-II A
// (MMA_C=4,MMA_M,MMA_N) => (MMA_A=(4, 2), MMA_M, MMA_N/2)
template <typename LayoutC>
CUTE_HOST_DEVICE static constexpr auto to_mma_a(const LayoutC& layout) {
auto l = logical_divide(layout.layout(), Shape<X, X, _2>{});
return make_layout(
make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
};

} // namespace llm
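As an illustration only, a small host-side sketch of what to_mn does to a fragment layout, assuming the CuTe headers and layout_convertor.h are on the include path; the (MMA=4, MMA_M=2, MMA_N=4) shape is a hypothetical stand-in for a real accumulator partition.

#include <cstdio>

#include <cute/tensor.hpp>

#include "layout_convertor.h"

int main() {
  using namespace cute;
  // Hypothetical accumulator fragment layout: (MMA=4, MMA_M=2, MMA_N=4), compact column-major.
  auto frag = make_layout(make_shape(Int<4>{}, Int<2>{}, Int<4>{}));
  // Regrouped as ((2, MMA_M), (2, MMA_N)) so rows and columns can be iterated directly.
  auto mn = llm::LayoutConvertor::to_mn(frag);
  print(mn);  // ((_2,_2),(_2,_4)):((_2,_4),(_1,_8))
  std::printf("\n");
  return 0;
}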
14 changes: 8 additions & 6 deletions src/kernels/attention/mask.h
@@ -2,6 +2,8 @@
#include <cute/config.hpp>
#include <cute/tensor.hpp>

#include "fast_math.h"

namespace llm {
using namespace cute;

@@ -10,17 +12,17 @@ struct Mask {
// Fragment type for alibi slopes
using FragmentT = decltype(make_tensor<float>(Int<ROWS_PER_THR>{}));

int q_len_;
int kv_len_;
int group_size_;
int sliding_window_;
int diagonal_offset_;
const int q_len_;
const int kv_len_;
const FastDivmod& group_size_;
const int sliding_window_;
const int diagonal_offset_;

FragmentT alibi_slopes_;

CUTE_HOST_DEVICE Mask(int q_len,
int kv_len,
int group_size,
const FastDivmod& group_size,
int sliding_window)
: q_len_(q_len),
kv_len_(kv_len),
27 changes: 15 additions & 12 deletions src/kernels/attention/mha_kernel_sm80.cuh
@@ -10,10 +10,10 @@
#include "cute/container/array_aligned.hpp"
#include "cute_extensions.cuh"
#include "fast_cast.cuh"
#include "layout_convertor.h"
#include "mask.h"
#include "mha_tile.h"
#include "online_softmax.cuh"
#include "ptx.cuh"

namespace llm {

@@ -65,7 +65,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
using DType = typename Traits::DType;

using TiledMma = typename Traits::TiledMma;
using Layout = typename Traits::LayoutConvertor;

using SmemLayoutQ = typename Traits::SmemLayoutQ;
using SmemLayoutK = typename Traits::SmemLayoutK;
@@ -88,20 +87,20 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
const int kv_head_idx = blockIdx.z;
const int tidx = threadIdx.x;

MHATile<Params> tile(params);

// preprocess input parameters
const int head_dim = params.head_dim;
const int group_size = params.group_size;
const float logits_soft_cap = params.logits_soft_cap;
const float sm_scale = params.sm_scale;
const float sm_scale_log2 = params.sm_scale_log2;

const auto& group_size = params.group_size;

// ProblemShape
// (q_packed_len, HEAD_DIM)
auto [Q, O] = tile.template get_qo_tile<DType>(batch_idx, kv_head_idx);
MHATile<Params> tile(params, batch_idx, kv_head_idx);
auto [Q, O] = tile.template get_qo_tile<DType>();
// (kv_len, HEAD_DIM)
auto [K, V] = tile.template get_kv_tile<DType>(batch_idx, kv_head_idx);
auto [K, V] = tile.template get_kv_tile<DType>();

const int q_packed_len = size<0>(Q);
const int q_len = q_packed_len / group_size;
@@ -253,7 +252,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
fast_cast(tSrAccS, tSrS);

// convert layout from gemm-I C to gemm-II A
auto tOrS = make_tensor(tSrS.data(), Layout::to_mma_a(tSrS.layout()));
auto tOrS =
make_tensor(tSrS.data(), LayoutConvertor::to_mma_a(tSrS.layout()));

// prefetch V^t
cute::copy(
@@ -307,7 +307,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(

// output accumulator, (MMA,MMA_M,MMA_K)
auto tOrO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mn(tOrO.layout()));
auto tOrO_mn =
make_tensor(tOrO.data(), LayoutConvertor::to_mn(tOrO.layout()));
clear(tOrO);

const int diagonal = (m_block_idx * kBlockM) / group_size + kv_len - q_len;
@@ -327,7 +328,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
if constexpr (SOFT_CAP) {
CUTE_UNROLL
for (int i = 0; i < size(tSrAccS); ++i) {
tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap);
tSrAccS(i) = tanh(tSrAccS(i) * logits_soft_cap);
}
}
};
@@ -356,12 +357,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(

// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_mn(tSrS.layout()));
auto tSrS_mn =
make_tensor(tSrS.data(), LayoutConvertor::to_mn(tSrS.layout()));

// identity tensor for score accumulator
auto tScS =
thr_mma.partition_C(make_identity_tensor(Shape<_BLK_M, _BLK_N>{}));
auto tScS_mn = make_tensor(tScS.data(), Layout::to_mn(tScS.layout()));
auto tScS_mn =
make_tensor(tScS.data(), LayoutConvertor::to_mn(tScS.layout()));

constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tSrS);
using Softmax = OnlineSoftmax<kRowsPerThr>;
4 changes: 2 additions & 2 deletions src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu
@@ -149,8 +149,8 @@ TEST_P(MHAKernelPagedKVTest, PageKV) {
block_ids.reserve(n_blocks);
for (int j = 0; j < n_blocks; ++j) {
// random assign block size
const int32_t id = absl::Uniform<int>(
absl::IntervalClosedClosed, gen, 1, total_blocks - 1);
const int32_t id =
absl::Uniform<int>(absl::IntervalClosedOpen, gen, 1, total_blocks);
// put first slot id of each block into block_table
block_ids.push_back(id * block_size);
}
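A side note on the test change, shown as a hedged host-only sketch that assumes Abseil's random headers: IntervalClosedClosed over [1, total_blocks - 1] and IntervalClosedOpen over [1, total_blocks) draw block ids from the same set, so the rewrite only simplifies the call.

#include <cstdio>

#include "absl/random/distributions.h"
#include "absl/random/random.h"

int main() {
  absl::BitGen gen;
  const int total_blocks = 8;  // hypothetical pool size
  // Both calls sample ids from {1, ..., total_blocks - 1}.
  const int closed = absl::Uniform<int>(absl::IntervalClosedClosed, gen, 1, total_blocks - 1);
  const int half_open = absl::Uniform<int>(absl::IntervalClosedOpen, gen, 1, total_blocks);
  std::printf("%d %d\n", closed, half_open);
  return 0;
}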
(diff for the remaining changed files not shown)
