Skip to content

Commit

Permalink
use zfill
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi committed Jan 23, 2025
1 parent 0aba085 commit 31f9c38
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 46 deletions.
50 changes: 24 additions & 26 deletions src/kernels/attention/attention_kernel_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_q = [&]() {
auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(
gmem_tiled_copy_Q,
tQgQ,
tQsQ,
Expand All @@ -159,30 +159,28 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_k = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
// Skip zero-filling out-of-bounds rows of K: the attention mask already sets those positions to -inf.
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_Q,
tKgK,
tKsK,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/false,
/*ZFILL_K=*/true>(gmem_tiled_copy_Q,
tKgK,
tKsK,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
};

Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
auto produce_v = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
// TODO: try skipping the out-of-bounds zero fill for V as well; doing so may cause NaN issues.
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_Q,
tVgV,
tVsV,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(gmem_tiled_copy_Q,
tVgV,
tVsV,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
};

TiledMma tiled_mma;
Expand Down Expand Up @@ -302,10 +300,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

// wait for smem copy done before gmem copy
__syncthreads();
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/false>(
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/false,
/*ZFILL_K=*/false>(
gmem_tiled_copy_O,
tOsO,
tOgO,
Expand Down
6 changes: 4 additions & 2 deletions src/kernels/attention/attention_traits_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,17 @@ struct AttentionTraitsSM80 {
// g2s tiled copy for q
using GmemTiledCopyQ =
decltype(detail::tiled_copy_selector<
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>,
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,
DType>,
BLK_K,
kThreadNum>());

// g2s tiled copy for kv
// TODO: choose based on BLK_K and kv cache type
using GmemTiledCopyKV =
decltype(detail::tiled_copy_selector<
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, KV_DType>,
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,
KV_DType>,
BLK_K,
kThreadNum>());

Expand Down
76 changes: 58 additions & 18 deletions src/kernels/attention/cute_extensions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,62 @@ CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a,
return elem_less(get<I>(a), get<I>(b));
}

template <bool EVEN_K,
bool EVEN_MN,
bool ZERO_FILL_MN,
bool ZERO_FILL_K,
class TiledCopy,
// Fills `dst` without copying real data from `src`.
//
// If the copy atom's traits provide a `with(bool)` method, the atom is invoked
// with the predicate set to false for every slice of the tensors; otherwise
// `dst` is simply cleared.
// NOTE(review): presumably a false predicate makes a ZFILL-style cp.async atom
// write zeros to `dst` instead of reading `src` -- confirm against the CuTe
// copy traits in use.
//
// copy_atom: atom whose `with(false).call(...)` performs the fill.
// src:       source tensor; must have the same rank as `dst`.
// dst:       destination tensor, taken as a forwarding reference. `TensorD`
//            must deduce to a non-reference type because `TensorD::rank` is
//            read below.
template <class Copy_Atom, class TensorS, class TensorD>
CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom,
                            const TensorS& src,
                            TensorD&& dst) {
  CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch.");

  // Compile-time probe: is `Traits::with(bool)` a valid expression for this
  // atom type?
  auto supports_predicate = cute::is_valid(
      [](auto atom)
          -> void_t<decltype(declval<typename decltype(atom)::Traits>()
                                 .with(true))> {},
      copy_atom);

  if constexpr (supports_predicate) {
    constexpr int kRank = TensorD::rank;
    if constexpr (kRank == 1) {
      // Rank-1 tensors map onto a single atom invocation.
      copy_atom.with(false).call(src, dst);
    } else {
      // Collapse every mode after the first into one, then issue one atom
      // invocation per element of that collapsed mode.
      Tensor grouped_src = group_modes<1, kRank>(src);
      Tensor grouped_dst = group_modes<1, kRank>(dst);
      CUTE_UNROLL
      for (int col = 0; col < size<1>(grouped_dst); ++col) {
        copy_atom.with(false).call(grouped_src(_, col), grouped_dst(_, col));
      }
    }
  } else {
    // The atom has no predicate hook, so fall back to clearing the
    // destination directly.
    clear(dst);
  }
}

// Lvalue-destination overload of zfill for cute::Copy_Atom.
//
// Forwards to the generic zfill overload that takes the destination as a
// forwarding reference (`TensorD&&`). That overload reads `TensorD::rank`, so
// it needs `TensorD` to deduce to a non-reference type, which only happens
// when the destination is passed as an rvalue.
//
// copy_atom: the copy atom used to perform the fill.
// src:       source tensor; must have the same rank as `dst`.
// dst:       destination tensor (named lvalue) to be filled.
template <class... CopyArgs, class TensorS, class TensorD>
CUTE_HOST_DEVICE void zfill(const Copy_Atom<CopyArgs...>& copy_atom,
                            const TensorS& src,
                            TensorD& dst) {
  // Pass `dst` as an rvalue so overload resolution picks the generic
  // `TensorD&&` overload. Passing it as an lvalue would re-select this
  // overload (it is the more specialized candidate) and recurse forever.
  // Nothing is moved from: the callee only uses `dst` as a named lvalue.
  zfill(copy_atom, src, static_cast<TensorD&&>(dst));
}

template <bool EVEN_MN,
bool EVEN_K,
bool ZFILL_MN,
bool ZFILL_K,
class CopyAtom,
class TV,
class Tiler,
class TensorS,
class TensorD,
class TensorC,
class Coord>
CUTE_HOST_DEVICE void safe_copy(
const TiledCopy& tiled_copy,
const TiledCopy<CopyAtom, TV, Tiler>& tiled_copy,
const TensorS& src, // (CPY, CPY_M/N, CPY_K)
TensorD& dst, // (CPY, CPY_M/N, CPY_K)
const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k)
const Coord& max_coord // max_coord(blk_m/n, blk_k)
) {
CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch.");
auto copy_atom = static_cast<const CopyAtom&>(tiled_copy);

if constexpr (!EVEN_MN && !EVEN_K) {
// handle both m/n and k oob
CUTE_UNROLL
Expand All @@ -46,16 +86,16 @@ CUTE_HOST_DEVICE void safe_copy(
CUTE_UNROLL
for (int ki = 0; ki < size<2>(src); ++ki) {
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki));
copy(copy_atom, src(_, mi, ki), dst(_, mi, ki));
} else {
if constexpr (ZERO_FILL_K) {
clear(dst(_, mi, ki));
if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki));
}
}
}
} else {
if constexpr (ZERO_FILL_MN) {
clear(dst(_, mi, _));
if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
}
}
}
Expand All @@ -64,10 +104,10 @@ CUTE_HOST_DEVICE void safe_copy(
CUTE_UNROLL
for (int mi = 0; mi < size<1>(src); ++mi) {
if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) {
copy(tiled_copy, src(_, mi, _), dst(_, mi, _));
copy(copy_atom, src(_, mi, _), dst(_, mi, _));
} else {
if constexpr (ZERO_FILL_MN) {
clear(dst(_, mi, _));
if constexpr (ZFILL_MN) {
zfill(copy_atom, src(_, mi, _), dst(_, mi, _));
}
}
}
Expand All @@ -76,16 +116,16 @@ CUTE_HOST_DEVICE void safe_copy(
CUTE_UNROLL
for (int ki = 0; ki < size<2>(src); ++ki) {
if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) {
copy(tiled_copy, src(_, _, ki), dst(_, _, ki));
copy(copy_atom, src(_, _, ki), dst(_, _, ki));
} else {
if constexpr (ZERO_FILL_K) {
clear(dst(_, _, ki));
if constexpr (ZFILL_K) {
zfill(copy_atom, src(_, _, ki), dst(_, _, ki));
}
}
}
} else {
// no oob, just copy
copy(tiled_copy, src, dst);
copy(copy_atom, src, dst);
}
}

Expand Down

0 comments on commit 31f9c38

Please sign in to comment.