Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kernel: revert experimental TiledMMA separation change. #401

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 29 additions & 69 deletions src/kernels/attention/mla_kernel_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// type alias
using DType = typename Traits::DType;

using TiledMma_QK = typename Traits::TiledMma_QK;
using TiledMma_PV = typename Traits::TiledMma_PV;
using TiledMma = typename Traits::TiledMma;
using Layout = typename Traits::LayoutConvertor;

using SmemLayoutQ = typename Traits::SmemLayoutQ;
using SmemLayoutKV = typename Traits::SmemLayoutKV;
using SmemLayoutP = typename Traits::SmemLayoutP;
using SmemLayoutQRope = typename Traits::SmemLayoutQRope;
using SmemLayoutKRope = typename Traits::SmemLayoutKRope;
using SmemLayoutVt = typename Traits::SmemLayoutVt;
Expand All @@ -63,8 +61,6 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ;
using SmemTiledCopyK = typename Traits::SmemTiledCopyK;
using SmemTiledCopyS = typename Traits::SmemTiledCopyS;
using SmemTiledCopyP = typename Traits::SmemTiledCopyP;
using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt;
using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

Expand Down Expand Up @@ -106,18 +102,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
extern __shared__ char smem[];
DType* q_smem = (DType*)smem;
DType* kv_smem = q_smem + cosize(SmemLayoutQ{});
DType* p_smem = kv_smem + cosize(SmemLayoutKV{});
DType* q_rope_smem = p_smem + cosize(SmemLayoutP{});
DType* q_rope_smem = kv_smem + cosize(SmemLayoutKV{});
DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{});

// (BLK_M, BLK_K, STAGES), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
// (BLK_N, BLK_K, STAGES), k-major
Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{});

// (BLK_M, BLK_N), k-major
Tensor sP = make_tensor(make_smem_ptr(p_smem), SmemLayoutP{});

// (BLK_M, ROPE_HEAD_DIM), k-major
Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{});
// (BLK_N, ROPE_HEAD_DIM), k-major
Expand All @@ -142,14 +134,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto tGCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, stage));
auto tGCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, stage));
cute::copy(gmem_tiled_copy_Q, tGCgQ, tGCsQ);
cp_async_fence();
};

auto produce_q_rope = [&]() {
auto tQgQ_rope = gmem_thr_copy_Q.partition_S(gQ_rope);
auto tQsQ_rope = gmem_thr_copy_Q.partition_D(sQ_rope);
cute::copy(gmem_tiled_copy_Q, tQgQ_rope, tQsQ_rope);
cp_async_fence();
};

// (CPY, CPY_N, CPY_K, STAGES)
Expand All @@ -159,29 +149,27 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// sK: (BLK_N, BLK_K, STAGES)
Tensor tGCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, stage));
cute::copy(gmem_tiled_copy_KV, tGCgKV, tGCsKV);
cp_async_fence();
};

Tensor tKsK_rope = gmem_thr_copy_KV.partition_D(sK_rope);
auto produce_k_rope = [&](int ni) {
auto tKgK_rope = gmem_thr_copy_KV.partition_S(gK_rope(_, _, ni));
cute::copy(gmem_tiled_copy_KV, tKgK_rope, tKsK_rope);
cp_async_fence();
};

TiledMma_QK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(tidx);
// GEMM-I: S = Q@K.T
// sQ/sK: (BLK_M, BLK_K, STAGES)
auto tSrQ = partition_fragment_A(
thr_mma_qk, sQ(_, _, _0{}), _, _2{}); // (MMA, MMA_M, _2)
thr_mma, sQ(_, _, _0{}), _, _2{}); // (MMA, MMA_M, _2)
auto tSrK = partition_fragment_B(
thr_mma_qk, sK(_, _, _0{}), _, _2{}); // (MMA, MMA_N, _2)
thr_mma, sK(_, _, _0{}), _, _2{}); // (MMA, MMA_N, _2)

auto tSrQ_rope =
partition_fragment_A(thr_mma_qk, sQ_rope, _, _2{}); // (MMA, MMA_M, _2)
partition_fragment_A(thr_mma, sQ_rope, _, _2{}); // (MMA, MMA_M, _2)
auto tSrK_rope =
partition_fragment_B(thr_mma_qk, sK_rope, _, _2{}); // (MMA, MMA_N, _2)
partition_fragment_B(thr_mma, sK_rope, _, _2{}); // (MMA, MMA_N, _2)

// s2r tiled copy for qkv
SmemTiledCopyQ smem_tiled_copy_Q;
Expand Down Expand Up @@ -228,7 +216,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(
smem_tiled_copy_K, tCsK_s(_, _, next_k), tCrK(_, _, (next_k & 1)));
}
cute::gemm(tiled_mma_qk, tSrQ(_, _, (k & 1)), tSrK(_, _, (k & 1)), tSrS);
cute::gemm(tiled_mma, tSrQ(_, _, (k & 1)), tSrK(_, _, (k & 1)), tSrS);
}
};

Expand All @@ -248,29 +236,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
tCsK_rope(_, _, next_k),
tCrK_rope(_, _, (next_k & 1)));
}
cute::gemm(tiled_mma_qk,
tSrQ_rope(_, _, (k & 1)),
tSrK_rope(_, _, (k & 1)),
tSrS);
cute::gemm(
tiled_mma, tSrQ_rope(_, _, (k & 1)), tSrK_rope(_, _, (k & 1)), tSrS);
}
};

// GEMM-II: O = softmax(S)@V
TiledMma_PV tiled_mma_pv;
auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
// sS: (BLK_M, BLK_N)
// (MMA, MMA_M, _2)
auto tOrP = partition_fragment_A(thr_mma_pv, sP, _, _2{});
// sVt: (BLK_K, BLK_N, STAGES)
// (MMA, MMA_N, _2)
auto tOrVt = partition_fragment_B(thr_mma_pv, sVt(_, _, _0{}), _, _2{});

SmemTiledCopyP smem_tiled_copy_P;
auto smem_thr_copy_P = smem_tiled_copy_P.get_slice(tidx);
// (CPY, CPY_M, CPY_K)
auto tCsP = smem_thr_copy_P.partition_S(sP);
// (CPY, CPY_M, _2)
auto tCrP = smem_thr_copy_P.retile_D(tOrP);
auto tOrVt = partition_fragment_B(thr_mma, sVt(_, _, _0{}), _, _2{});

SmemTiledCopyVt smem_tiled_copy_Vt;
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
Expand All @@ -281,44 +254,27 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

// O = softmax(S)*V
// tOrS: (MMA,MMA_M,MMA_K)
auto compute_pv = [&](auto& tOrO, int s) {
auto compute_pv = [&](const auto& tOrS, auto& tOrO, int s) {
// (MMA,MMA_M,MMA_N, STAGES)
auto tOrO_s = tOrO(_, _, _, s);

// (CPY, CPY_N, CPY_K, STAGES)
auto tCsVt_s = tCsVt(_, _, _, s);
// tCsVt_s: (CPY, CPY_N, CPY_K) => tCrVt: (CPY, CPY_N, _2)
cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));

CUTE_UNROLL
for (int k = 0; k < size<2>(tCsVt_s); ++k) {
if (k != size<2>(tCsVt_s) - 1) {
for (int k = 0; k < size<2>(tOrS); ++k) {
if (k != size<2>(tOrS) - 1) {
const auto next_k = k + 1;
cute::copy(
smem_tiled_copy_P, tCsP(_, _, next_k), tCrP(_, _, (next_k & 1)));
cute::copy(smem_tiled_copy_Vt,
tCsVt_s(_, _, next_k),
tCrVt(_, _, (next_k & 1)));
}
cute::gemm(
tiled_mma_pv, tOrP(_, _, (k & 1)), tOrVt(_, _, (k & 1)), tOrO_s);
cute::gemm(tiled_mma, tOrS(_, _, k), tOrVt(_, _, (k & 1)), tOrO_s);
}
};

SmemTiledCopyS smem_tiled_copy_S;
auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(tidx);

auto save_scores = [&](const auto& tSrS) {
// cast Accumulator to Element type
auto tSrS_ = make_tensor_like<DType>(tSrS);
fast_cast(tSrS, tSrS_);
// copy scores from rmem to smem
auto tCrS = smem_thr_copy_S.retile_S(tSrS_);
auto tCsS = smem_thr_copy_S.partition_D(sP);
cute::copy(smem_tiled_copy_S, tCrS, tCsS);
};

// tOrO: (MMA,MMA_M,MMA_K,STAGES)
auto epilogue = [&](const auto& tOrO) {
// write output to gmem
Expand Down Expand Up @@ -353,8 +309,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
};

// output accumulator: (MMA, MMA_M, MMA_K, STAGES)
auto tOrO =
partition_fragment_C(thr_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
auto tOrO = partition_fragment_C(thr_mma, Shape<_BLK_M, _BLK_K, _STAGES>{});
auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
clear(tOrO);

Expand All @@ -371,9 +326,11 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

// produce k_rope: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
produce_k_rope(0);
cp_async_fence();
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
produce_kv(0, s);
cp_async_fence();
}

// ############### Mainloop ###############
Expand All @@ -384,7 +341,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrS = partition_fragment_C(tiled_mma_qk, Shape<_BLK_M, _BLK_N>{});
auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_mn(tSrS.layout()));
clear(tSrS);

Expand All @@ -408,23 +365,26 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(

softmax.rescale(tSrS_mn, tOrO_mn);

// save tSrS from rmem to smem
save_scores(tSrS);
__syncthreads();

// 3> O = softmax(S)*V
// cast scores from Accumulator to Element
auto tSrS_ = make_tensor_like<DType>(tSrS);
fast_cast(tSrS, tSrS_);
// convert layout from gemm-I C to gemm-II A
auto tOrS = make_tensor(tSrS_.data(), Layout::to_mma_a(tSrS_.layout()));
const auto next_ni = ni + 1;
if (next_ni != n_block_max) {
produce_k_rope(next_ni);
cp_async_fence();
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
compute_pv(tOrS, tOrO, s);
produce_kv(next_ni, s);
cp_async_fence();
}
} else {
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
compute_pv(tOrS, tOrO, s);
}
}
}
Expand Down
56 changes: 19 additions & 37 deletions src/kernels/attention/mla_traits_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,9 @@ struct MLATraitsSM80 {
std::conditional_t<std::is_same_v<DType, cute::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
using TiledMma_QK = TiledMMA<MMA_Atom_,
Layout<Shape<_4, _1, _1>>, // warp layout 4x1x1
Tile<_64, _16, _16>>; // Prom Shape 64x16x16

using TiledMma_PV = TiledMMA<MMA_Atom_,
Layout<Shape<_4, _1, _1>>, // warp layout 4x1x1
Tile<_64, _16, _16>>; // Prom Shape 64x16x16

// use 128-bit vectorizing copy
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
using TiledMma = TiledMMA<MMA_Atom_,
Layout<Shape<_4, _1, _1>>, // warp layout 4x1x1
Tile<_64, _16, _16>>; // Prom Shape 64x16x16

// Layout convertor for TiledMMA (64x16x16)
using LayoutConvertor = detail::LayoutConvertor;
Expand All @@ -103,10 +96,6 @@ struct MLATraitsSM80 {
decltype(tile_to_shape(SmemLayoutAtom{},
Shape<_BLK_N, _BLK_K, _STAGES>{}));

// P smem: (BLK_M, BLK_N)
using SmemLayoutP =
decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _BLK_N>{}));

// V^T smem: (BLK_K, BLK_N, STAGES)
using SmemLayoutVt = decltype(permute<1, 0, 2>(SmemLayoutKV{}));

Expand Down Expand Up @@ -137,53 +126,46 @@ struct MLATraitsSM80 {
// g2s tiled copy for kv
using GmemTiledCopyKV = GmemTiledCopyQ;

// s2r tiled copy for gemm-I S = Q*K^T
// s2r tiled copy for gemm-I
using SmemTiledCopyQ =
decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
TiledMma_QK{}));
TiledMma{}));
using SmemTiledCopyK =
decltype(make_tiled_copy_B(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
TiledMma_QK{}));
TiledMma{}));

// r2s tiled copy for gemm-I S
using SmemTiledCopyS =
decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
TiledMma_QK{}));

// s2r tiled copy for gemm-II: O = P*V^T
using SmemTiledCopyP =
decltype(make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, DType>{},
TiledMma_PV{}));
// s2r tiled copy for gemm-II
using SmemTiledCopyVt =
decltype(make_tiled_copy_B(Copy_Atom<SM75_U16x8_LDSM_T, DType>{},
TiledMma_PV{}));

// r2s tiled copy for gemm-II O
using SmemTiledCopyO =
decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
TiledMma_PV{}));
TiledMma{}));

// ******* Epilogue *******

// O smem: (BLK_M, BLK_K, STAGES) k-major
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtom{},
Shape<_BLK_M, _BLK_K, _STAGES>{}));

// use 128-bit vectorizing copy
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;

// s2g tiled copy for O
using GmemTiledCopyO = decltype(make_tiled_copy(
Copy_Atom<VectorizingCopy, DType>{},
GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4)
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
));

// r2s tiled copy for O
using SmemTiledCopyO =
decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
TiledMma{}));

// constexpr values for kernel launch
static constexpr size_t kSmemSize =
sizeof(DType) *
(cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{}) + cosize(SmemLayoutP{}) +
cosize(SmemLayoutQRope{}) + cosize(SmemLayoutKRope{}));
sizeof(DType) * (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{}) +
cosize(SmemLayoutQRope{}) + cosize(SmemLayoutKRope{}));

static constexpr size_t kThreadNum =
std::max(size(TiledMma_QK{}), size(TiledMma_PV{}));
static constexpr size_t kThreadNum = size(TiledMma{});
};

} // namespace llm
8 changes: 4 additions & 4 deletions src/kernels/attention/mla_traits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using namespace cute;
template <typename Traits>
void test_mla_traits() {
// type alias
using TiledMma_QK = typename Traits::TiledMma_QK;
using TiledMma = typename Traits::TiledMma;
using Layout = typename Traits::LayoutConvertor;

using SmemLayoutQ = typename Traits::SmemLayoutQ;
Expand Down Expand Up @@ -47,9 +47,9 @@ void test_mla_traits() {
// print("sQ_rope:"); print(sQ_rope);print("\n");
// print("sKV_rope:"); print(sKV_rope);print("\n");

TiledMma_QK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_slice(0);
// auto tOrVt = thr_mma_qk.partition_fragment_B(sVt);
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(0);
auto tOrVt = thr_mma.partition_fragment_B(sVt);
// TODO: add tests for layout conformance
}

Expand Down