Skip to content

Commit

Permalink
kernel: added ping-pong rmem support for MLA (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi authored Feb 13, 2025
1 parent 5cb2887 commit 38c3a78
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 120 deletions.
1 change: 1 addition & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
fetch-depth: 0

- name: Run clang-format
shell: /usr/bin/bash {0}
run: |
diff=`git-clang-format --extensions="c,h,cc,cp,cpp,c++,cxx,hh,hpp,hxx,inc,cu,cuh" --commit ${{ github.event.pull_request.base.sha }} --diff`
[ "$diff" = "no modified files to format" ] && exit 0
Expand Down
37 changes: 37 additions & 0 deletions src/kernels/attention/cute_extensions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,45 @@ constexpr bool
has_with_bool<Copy_Atom,
cute::void_t<decltype(declval<typename Copy_Atom::Traits>()
.with(declval<bool>()))>> = true;

template <typename Layout, typename Shape>
CUTE_HOST_DEVICE constexpr auto with_shape(Layout l, Shape s) {
if constexpr (is_underscore<Shape>::value) {
return l;
} else {
return l.with_shape(s);
}
}

} // namespace detail

// returns a fragment with the a shape (MMA, mma_m, mma_k)
template <typename ThrMMA, class ATensor, typename ShapeM, typename ShapeK>
CUTE_HOST_DEVICE constexpr auto partition_fragment_A(const ThrMMA& thr_mma,
ATensor&& atensor,
const ShapeM& mma_m,
const ShapeK& mma_k) {
auto a = thr_mma.partition_A(atensor);
auto l = get_nonswizzle_portion(a.layout());
auto a_l = make_layout(get<0>(l),
detail::with_shape(get<1>(l), mma_m),
detail::with_shape(get<2>(l), mma_k));
return thr_mma.make_fragment_A(a_l);
}

template <typename ThrMMA, class BTensor, typename ShapeN, typename ShapeK>
CUTE_HOST_DEVICE constexpr auto partition_fragment_B(const ThrMMA& thr_mma,
BTensor&& btensor,
const ShapeN& mma_n,
const ShapeK& mma_k) {
auto b = thr_mma.partition_B(btensor);
auto l = get_nonswizzle_portion(b.layout());
auto b_l = make_layout(get<0>(l),
detail::with_shape(get<1>(l), mma_n),
detail::with_shape(get<2>(l), mma_k));
return thr_mma.make_fragment_B(b_l);
}

template <int... Is, int B, int M, int S, class Offset, class LayoutB>
CUTE_HOST_DEVICE constexpr auto permute(
const ComposedLayout<Swizzle<B, M, S>, Offset, LayoutB>& c) {
Expand Down
Loading

0 comments on commit 38c3a78

Please sign in to comment.