From 31f9c381ee5151f59159ace69578ba645e5b117d Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Wed, 22 Jan 2025 23:44:12 -0800 Subject: [PATCH] use zfill --- .../attention/attention_kernel_sm80.cuh | 50 ++++++------ src/kernels/attention/attention_traits_sm80.h | 6 +- src/kernels/attention/cute_extensions.cuh | 76 ++++++++++++++----- 3 files changed, 86 insertions(+), 46 deletions(-) diff --git a/src/kernels/attention/attention_kernel_sm80.cuh b/src/kernels/attention/attention_kernel_sm80.cuh index 8d03c17a..7db50ba7 100644 --- a/src/kernels/attention/attention_kernel_sm80.cuh +++ b/src/kernels/attention/attention_kernel_sm80.cuh @@ -143,10 +143,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_q = [&]() { auto tQgQ = gmem_thr_copy_Q.partition_S(gQ); auto tQsQ = gmem_thr_copy_Q.partition_D(sQ); - safe_copy( + safe_copy( gmem_tiled_copy_Q, tQgQ, tQsQ, @@ -159,30 +159,28 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_k = [&](int ni) { auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); // skip zero fill oob for k since mask will mask out oob with -inf - safe_copy( - gmem_tiled_copy_Q, - tKgK, - tKsK, - tKcKV, - make_coord(kv_len - ni * kBlockN, head_dim)); + safe_copy(gmem_tiled_copy_Q, + tKgK, + tKsK, + tKcKV, + make_coord(kv_len - ni * kBlockN, head_dim)); }; Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); auto produce_v = [&](int ni) { auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); // TODO: skip zero fill oob for v, may have nan issue - safe_copy( - gmem_tiled_copy_Q, - tVgV, - tVsV, - tKcKV, - make_coord(kv_len - ni * kBlockN, head_dim)); + safe_copy(gmem_tiled_copy_Q, + tVgV, + tVsV, + tKcKV, + make_coord(kv_len - ni * kBlockN, head_dim)); }; TiledMma tiled_mma; @@ -302,10 +300,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { // wait for smem copy done before gmem copy __syncthreads(); - safe_copy( + safe_copy( gmem_tiled_copy_O, tOsO, tOgO, diff --git a/src/kernels/attention/attention_traits_sm80.h b/src/kernels/attention/attention_traits_sm80.h index 63d2c58a..24f5a15b 100644 --- a/src/kernels/attention/attention_traits_sm80.h +++ b/src/kernels/attention/attention_traits_sm80.h @@ -159,7 +159,8 @@ struct AttentionTraitsSM80 { // g2s tiled copy for q using GmemTiledCopyQ = decltype(detail::tiled_copy_selector< - Copy_Atom, DType>, + Copy_Atom, + DType>, BLK_K, kThreadNum>()); @@ -167,7 +168,8 @@ struct AttentionTraitsSM80 { // TODO: choose based on BLK_K and kv cache type using GmemTiledCopyKV = decltype(detail::tiled_copy_selector< - Copy_Atom, KV_DType>, + Copy_Atom, + KV_DType>, BLK_K, kThreadNum>()); diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/cute_extensions.cuh index 1b5316df..e2b6ed00 100644 --- a/src/kernels/attention/cute_extensions.cuh +++ b/src/kernels/attention/cute_extensions.cuh @@ -22,22 +22,62 @@ CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a, return elem_less(get(a), get(b)); } -template +CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom, + const TensorS& src, + TensorD&& dst) { + CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch."); + + auto has_with_bool = cute::is_valid( + [](auto t) -> void_t() + .with(true))> {}, + copy_atom); + if constexpr (has_with_bool) { + constexpr int R = TensorD::rank; + if constexpr (R == 1) { // Dispatch the copy + copy_atom.with(false).call(src, dst); + } else { // Loop over all but the first mode + Tensor src_v = group_modes<1, R>(src); + Tensor dst_v = group_modes<1, R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.with(false).call(src_v(_, i), dst_v(_, i)); + } + } + } else { + // just call clear if no with method + clear(dst); + } +} + +template +CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom, + const TensorS& src, + TensorD& dst) { + zfill(copy_atom, src, dst); +} + +template CUTE_HOST_DEVICE void safe_copy( - const TiledCopy& tiled_copy, + const TiledCopy& tiled_copy, const TensorS& src, // (CPY, CPY_M/N, CPY_K) TensorD& dst, // (CPY, CPY_M/N, CPY_K) const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k) const Coord& max_coord // max_coord(blk_m/n, blk_k) ) { + CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch."); + auto copy_atom = static_cast(tiled_copy); + if constexpr (!EVEN_MN && !EVEN_K) { // handle both m/n and k oob CUTE_UNROLL @@ -46,16 +86,16 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { - copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki)); + copy(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } else { - if constexpr (ZERO_FILL_K) { - clear(dst(_, mi, ki)); + if constexpr (ZFILL_K) { + zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } } } } else { - if constexpr (ZERO_FILL_MN) { - clear(dst(_, mi, _)); + if constexpr (ZFILL_MN) { + zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); } } } @@ -64,10 +104,10 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int mi = 0; mi < size<1>(src); ++mi) { if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) { - copy(tiled_copy, src(_, mi, _), dst(_, mi, _)); + copy(copy_atom, src(_, mi, _), dst(_, mi, _)); } else { - if constexpr (ZERO_FILL_MN) { - clear(dst(_, mi, _)); + if constexpr (ZFILL_MN) { + zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); } } } @@ -76,16 +116,16 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { - copy(tiled_copy, src(_, _, ki), dst(_, _, ki)); + copy(copy_atom, src(_, _, ki), dst(_, _, ki)); } else { - if constexpr (ZERO_FILL_K) { - clear(dst(_, _, ki)); + if constexpr (ZFILL_K) { + zfill(copy_atom, src(_, _, ki), dst(_, _, ki)); } } } } else { // no oob, just copy - copy(tiled_copy, src, dst); + copy(copy_atom, src, dst); } }