From 7044f667a804708f32d64d1b4d8bbf55c0953f8f Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Thu, 6 Feb 2025 17:37:51 -0800
Subject: [PATCH 1/2] kernel: fix register spilling issue for attention head_dim=256 (#397)

---
 src/kernels/attention/mha_kernel_sm80.cuh     | 88 ++++++-------------
 src/kernels/attention/mha_sm80_bench.cu       | 16 +---
 .../attention/mha_sm80_pagedkv_bench.cu       | 16 +---
 src/kernels/attention/mha_traits_sm80.h       |  2 +-
 src/kernels/attention/online_softmax.cuh      | 65 +++++++++-----
 5 files changed, 72 insertions(+), 115 deletions(-)

diff --git a/src/kernels/attention/mha_kernel_sm80.cuh b/src/kernels/attention/mha_kernel_sm80.cuh
index 3e216e15..46e3a3ab 100644
--- a/src/kernels/attention/mha_kernel_sm80.cuh
+++ b/src/kernels/attention/mha_kernel_sm80.cuh
@@ -299,21 +299,6 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     return;
   }

-  // ############### Prologue ###############
-  int n_block_idx = n_block_max - 1;
-  // produce query: [] => [q]
-  produce_query();
-  cp_async_fence();
-  // produce key: [q] => [q, k]
-  produce_key(n_block_idx);
-  cp_async_fence();
-
-  // ############### Mainloop ###############
-  // attention score accumulator, (MMA,MMA_M,MMA_N)
-  auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
-  auto tSrAccS_rc_view =
-      make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
-
   auto apply_logits_soft_cap = [&](auto& tSrAccS) {
     if constexpr (SOFT_CAP) {
       CUTE_UNROLL
@@ -323,7 +308,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     }
   };

-  constexpr int kMMA_M = size<1>(tSrAccS);
+  constexpr int kMMA_M = size<1>(tOrAccO);

  using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
  using Mask = Mask<kBlockM, kBlockN, ALIBI, LOCAL>;
@@ -338,12 +323,26 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
       sm_scale,
       params.alibi_slopes_ptr);

-  // seperate oob mask iterations for better performance
+  // ############### Prologue ###############
+  // produce query: [] => [q]
+  produce_query();
+  cp_async_fence();
+  // produce key: [q] => [q, k]
+  produce_key(n_block_max - 1);
+  cp_async_fence();
+
+  // ############### Mainloop ###############
   constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
+  const int n_blocks = n_block_max - n_block_min;

-  // oob mask iterations
-  CUTE_UNROLL
-  for (int i = 0; i < n_oob_mask; ++i) {
+  CUTE_NO_UNROLL
+  for (int i = 0; i < n_blocks; ++i) {
+    const int n_block_idx = n_block_max - 1 - i;
+
+    // attention score accumulator, (MMA,MMA_M,MMA_N)
+    auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
+    auto tSrAccS_rc_view =
+        make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
     clear(tSrAccS);

     // wait key, queue: [q, k] => []
@@ -361,57 +360,20 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

     // 1> S = Q@K.T
     compute_qk(tSrAccS);

-    if constexpr (SOFT_CAP) {
-      apply_logits_soft_cap(tSrAccS);
-    }
-
-    mask.apply</*OOB_MASK=*/true>(tSrAccS_rc_view, n_block_idx);
-    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
-
     // wait value, [v] => []
     cp_async_wait<0>();
     __syncthreads();

-    // produce next key: [] => [k]
-    if (n_block_idx > n_block_min) {
-      produce_key_no_oob(n_block_idx - 1);
-    }
-    cp_async_fence();
-
-    // 2> O = softmax(S)*V
-    compute_sv(tSrAccS, tOrAccO);
-
-    --n_block_idx;
-    if (n_block_idx < n_block_min) {
-      // no more kv blocks to process
-      break;
-    }
-  }
-
-  // non-oob mask iterations
-  CUTE_NO_UNROLL
-  for (; n_block_idx >= n_block_min; --n_block_idx) {
-    clear(tSrAccS);
-
-    // wait key, queue: [q, k] => []
-    cp_async_wait<0>();
-    __syncthreads();
-
-    // produce value, [] => [v]
-    produce_value_no_oob(n_block_idx);
-    cp_async_fence();
-
-    // 1> S = Q@K.T
-    compute_qk(tSrAccS);
-
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }

-    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
-    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
-
-    // wait value, [v] => []
-    cp_async_wait<0>();
-    __syncthreads();
+    if (i < n_oob_mask) {
+      mask.apply</*OOB_MASK=*/true>(tSrAccS_rc_view, n_block_idx);
+    } else {
+      mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
+    }
+    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

     // produce next key: [] => [k]
     if (n_block_idx > n_block_min) {
diff --git a/src/kernels/attention/mha_sm80_bench.cu b/src/kernels/attention/mha_sm80_bench.cu
index 18eb7147..be41448c 100644
--- a/src/kernels/attention/mha_sm80_bench.cu
+++ b/src/kernels/attention/mha_sm80_bench.cu
@@ -7,22 +7,10 @@
 #include "mha_dispatch_sm80.cuh"
 #include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
+#include "static_dispatch.h"

 using namespace llm;

-#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
-  [&] {                                                    \
-    if (HEAD_DIM_V <= 64) {                                \
-      constexpr static int HEAD_DIM_NAME = 64;             \
-      return __VA_ARGS__();                                \
-    } else if (HEAD_DIM_V <= 128) {                        \
-      constexpr static int HEAD_DIM_NAME = 128;            \
-      return __VA_ARGS__();                                \
-    } else {                                               \
-      assert(false);                                       \
-    }                                                      \
-  }()
-
 void mha_bench_sm80(nvbench::state& state) {
   // Collect CUPTI metrics
   state.collect_cupti_metrics();
@@ -82,7 +70,7 @@ void mha_bench_sm80(nvbench::state& state) {
   params.sliding_window = sliding_window;

   state.exec([&](nvbench::launch& launch) {
-    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
       run_mha_kernel_sm80(params, launch.get_stream());
     });
   });
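Both benchmark files previously carried a private copy of this dispatch macro; the patch deletes them in favor of the shared DISPATCH_HEAD_DIM from static_dispatch.h (the identical removal repeats in the paged-KV benchmark below). static_dispatch.h itself is not shown in this diff; judging from the macro removed above, it presumably keeps the same if/else-if shape, extended with a 256 branch to match the head_dim=256 support fixed by this patch. A self-contained sketch under those assumptions (the real header may differ):

#include <cassert>
#include <cstdio>

// Hypothetical stand-in for static_dispatch.h's DISPATCH_HEAD_DIM: it maps a
// runtime head_dim onto a compile-time constant HEAD_DIM_NAME, so the body
// can instantiate a template per supported head dimension.
#define DISPATCH_HEAD_DIM(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
  [&] {                                                   \
    if (HEAD_DIM_V <= 64) {                               \
      constexpr static int HEAD_DIM_NAME = 64;            \
      return __VA_ARGS__();                               \
    } else if (HEAD_DIM_V <= 128) {                       \
      constexpr static int HEAD_DIM_NAME = 128;           \
      return __VA_ARGS__();                               \
    } else if (HEAD_DIM_V <= 256) {                       \
      constexpr static int HEAD_DIM_NAME = 256;           \
      return __VA_ARGS__();                               \
    } else {                                              \
      assert(false);                                      \
    }                                                     \
  }()

// Hypothetical stand-in for run_mha_kernel_sm80: the dispatched lambda sees
// HEAD_DIM as a constexpr and can use it as a template argument.
template <int HEAD_DIM>
void run_kernel() {
  std::printf("kernel instantiated for HEAD_DIM=%d\n", HEAD_DIM);
}

int main() {
  const int head_dim = 256;  // runtime value, e.g. read from a model config
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] { run_kernel<HEAD_DIM>(); });
  return 0;
}

Dispatching this way converts the runtime head dimension into a compile-time constant, so each supported size gets its own kernel instantiation while unsupported sizes fail loudly.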
diff --git a/src/kernels/attention/mha_sm80_pagedkv_bench.cu b/src/kernels/attention/mha_sm80_pagedkv_bench.cu
index 4a1b40ac..329601d2 100644
--- a/src/kernels/attention/mha_sm80_pagedkv_bench.cu
+++ b/src/kernels/attention/mha_sm80_pagedkv_bench.cu
@@ -8,22 +8,10 @@
 #include "mha_dispatch_sm80.cuh"
 #include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
+#include "static_dispatch.h"

 using namespace llm;

-#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
-  [&] {                                                    \
-    if (HEAD_DIM_V <= 64) {                                \
-      constexpr static int HEAD_DIM_NAME = 64;             \
-      return __VA_ARGS__();                                \
-    } else if (HEAD_DIM_V <= 128) {                        \
-      constexpr static int HEAD_DIM_NAME = 128;            \
-      return __VA_ARGS__();                                \
-    } else {                                               \
-      assert(false);                                       \
-    }                                                      \
-  }()
-
 void mha_bench_sm80(nvbench::state& state) {
   // Collect CUPTI metrics
   state.collect_cupti_metrics();
@@ -130,7 +118,7 @@ void mha_bench_sm80(nvbench::state& state) {
   params.block_cu_lens = block_cu_lens.const_data_ptr();

   state.exec([&](nvbench::launch& launch) {
-    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
       run_mha_kernel_sm80(params, launch.get_stream());
     });
   });
diff --git a/src/kernels/attention/mha_traits_sm80.h b/src/kernels/attention/mha_traits_sm80.h
index 6299d80e..3866b793 100644
--- a/src/kernels/attention/mha_traits_sm80.h
+++ b/src/kernels/attention/mha_traits_sm80.h
@@ -93,7 +93,7 @@ struct MHATraitsSM80 {
   // Tiled copy for QKV
   // g2s tiled copy for q
   using GmemTiledCopyQ = decltype(make_tiled_copy(
-      Copy_Atom, DType>{},
+      Copy_Atom, DType>{},
       GmemCopyThrLayout{},      // Thr layout: (_16,_8)/(_32, _4)
       Layout<Shape<_1, _8>>{}   // Val layout: 8 vals per read
       ));
diff --git a/src/kernels/attention/online_softmax.cuh b/src/kernels/attention/online_softmax.cuh
index 36a01705..d2371fd4 100644
--- a/src/kernels/attention/online_softmax.cuh
+++ b/src/kernels/attention/online_softmax.cuh
@@ -52,53 +52,72 @@ struct OnlineSoftmax {

   // computes the softmax scores and rescales the output
   //   - score = exp(score - row_max`)
-  //   - O = O * s_scale
+  //   - o = o * s_scale
   //   - internal: row_sum = row_sum * s_scale + row_sum`
   template <typename FragmentS, typename FragmentO>
   CUTE_DEVICE void rescale(FragmentS& rAccS, FragmentO& rAccO) {
+    // row_max = max(row_max, scores)
+    FragmentT pre_row_max;
+    cute::copy(row_max_, pre_row_max);
     CUTE_UNROLL
     for (int si = 0; si < size<0>(rAccS); ++si) {
-      // rowmax across 4 threads
-      float cur_rowmax = row_max_(si);
+      float row_max = row_max_(si);
+      // rowmax within a thread
       CUTE_UNROLL
       for (int sj = 0; sj < size<1>(rAccS); ++sj) {
-        cur_rowmax = max(cur_rowmax, rAccS(si, sj));
+        row_max = max(row_max, rAccS(si, sj));
       }
-      cur_rowmax = detail::group_reduce_max<4>(cur_rowmax);
+      // rowmax across 4 threads
+      row_max_(si) = detail::group_reduce_max<4>(row_max);
+    }

-      // scores = exp(scores - row_max)
-      const float rowmax_scale = cur_rowmax * sm_scale_;
-      float cur_rowsum = 0;
+    // o = o * s_scale
+    CUTE_UNROLL
+    for (int si = 0; si < size<0>(rAccO); ++si) {
+      const float s_scale =
+          ptx::exp2((pre_row_max(si) - row_max_(si)) * sm_scale_);
+      CUTE_UNROLL
+      for (int sj = 0; sj < size<1>(rAccO); ++sj) {
+        rAccO(si, sj) *= s_scale;
+      }
+    }
+
+    // scores = exp(scores - row_max)
+    CUTE_UNROLL
+    for (int si = 0; si < size<0>(rAccS); ++si) {
+      const float rowmax_scale = row_max_(si) * sm_scale_;
       CUTE_UNROLL
       for (int sj = 0; sj < size<1>(rAccS); sj++) {
         rAccS(si, sj) = ptx::exp2(rAccS(si, sj) * sm_scale_ - rowmax_scale);
-        cur_rowsum += rAccS(si, sj);
       }
+    }

-      // scores_scale = exp(max - cur_rowmax)
-      const float scores_scale =
-          ptx::exp2(row_max_(si) * sm_scale_ - rowmax_scale);
-      // o_2 = o_1 * s_scale
+    // row_sum = row_sum * s_scale + row_sum`
+    CUTE_UNROLL
+    for (int si = 0; si < size<0>(rAccS); ++si) {
+      const float s_scale =
+          ptx::exp2((pre_row_max(si) - row_max_(si)) * sm_scale_);
+      row_sum_(si) *= s_scale;
       CUTE_UNROLL
-      for (int sj = 0; sj < size<1>(rAccO); ++sj) {
-        rAccO(si, sj) *= scores_scale;
+      for (int sj = 0; sj < size<1>(rAccS); sj++) {
+        // rowsum within a thread
+        row_sum_(si) += rAccS(si, sj);
       }
-
-      // update row_max and row_sum
-      row_max_(si) = cur_rowmax;
-      // s_2 = s_1 * s_scale + row_sum
-      row_sum_(si) = row_sum_(si) * scores_scale + cur_rowsum;
     }
   }

-  // finalizes the softmax computation with O = O / row_sum
+  // finalizes the softmax computation with o = o / row_sum
   template <typename FragmentO>
   CUTE_DEVICE void finalize(FragmentO& rAccO) {
     CUTE_UNROLL
-    for (int oi = 0; oi < size<0>(rAccO); ++oi) {
+    for (int i = 0; i < size(row_sum_); ++i) {
       // rowsum across 4 threads
-      row_sum_(oi) = detail::group_reduce_sum<4>(row_sum_(oi));
+      row_sum_(i) = detail::group_reduce_sum<4>(row_sum_(i));
+    }
+    // o = o / row_sum
+    CUTE_UNROLL
+    for (int oi = 0; oi < size<0>(rAccO); ++oi) {
       CUTE_UNROLL
       for (int oj = 0; oj < size<1>(rAccO); ++oj) {
         rAccO(oi, oj) *= ptx::rcp(row_sum_(oi));
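The register-spilling fix above works from three directions. The score accumulator tSrAccS is now declared inside the KV-block loop instead of at function scope, so its registers can be reclaimed between iterations. The two separately unrolled loops (OOB-masked and non-OOB) are merged into a single CUTE_NO_UNROLL loop that selects the mask flavor with a runtime branch on i < n_oob_mask. And OnlineSoftmax::rescale() is restructured from one fused per-row loop into four short passes (row max, output rescale, exponentiation, row sum), so each pass's temporaries die before the next one starts. The arithmetic itself is the standard online-softmax recurrence; here is a host-side sketch for a single row, in plain C++ with no CUTE fragments and no cross-thread reductions, assuming sm_scale already folds in the log2(e) factor implied by the kernel's use of exp2:

#include <algorithm>
#include <cmath>
#include <vector>

// One row's online-softmax state; mirrors row_max_/row_sum_ in the kernel.
struct OnlineSoftmaxRow {
  float sm_scale;             // softmax scale, assumed pre-multiplied by log2(e)
  float row_max = -INFINITY;  // running max over all blocks seen so far
  float row_sum = 0.0f;       // running sum of exponentiated scores

  // Fold in one block of scores, rescaling the partial output o in place.
  void rescale(std::vector<float>& scores, std::vector<float>& o) {
    // pass 1: row_max = max(row_max, scores)
    const float pre_row_max = row_max;
    for (float s : scores) row_max = std::max(row_max, s);

    // pass 2: o = o * s_scale, with s_scale = exp2((pre_row_max - row_max) * sm_scale)
    const float s_scale = std::exp2((pre_row_max - row_max) * sm_scale);
    for (float& v : o) v *= s_scale;

    // pass 3: scores = exp2(scores * sm_scale - row_max * sm_scale)
    float blk_sum = 0.0f;
    for (float& s : scores) {
      s = std::exp2(s * sm_scale - row_max * sm_scale);
      blk_sum += s;  // pass 4 in the kernel: row-sum of the new scores
    }

    // row_sum = row_sum * s_scale + row_sum`
    row_sum = row_sum * s_scale + blk_sum;
  }

  // o = o / row_sum, applied once after every block has been folded in.
  void finalize(std::vector<float>& o) {
    for (float& v : o) v /= row_sum;
  }
};

In the deleted fused loop, cur_rowmax, cur_rowsum and scores_scale were all live across a whole row at once, on top of the score and output fragments; splitting the passes lets the compiler retire each temporary early, which is presumably what eliminates the spills at head_dim=256.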
From 532f6383bbabacc257db4c1f4ea860d2532039f9 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Thu, 6 Feb 2025 19:24:40 -0800
Subject: [PATCH 2/2] upgrade libtorch to 2.6.0 and cutlass to 3.8.0 (#398)

---
 .github/workflows/build_wheel.yml   |  7 ++-----
 .github/workflows/package_test.yml  |  2 +-
 .github/workflows/publish_wheel.yml |  2 +-
 .github/workflows/release_test.yml  |  2 +-
 CMakeLists.txt                      | 20 ++++++++++----------
 README.md                           |  2 +-
 docs/source/index.rst               |  2 +-
 docs/source/quick_start.rst         | 18 ++++++++++++++++++
 scalellm/downloader.py              |  7 ++++++-
 third_party/cutlass                 |  2 +-
 10 files changed, 42 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml
index db0b5eef..1957cc89 100644
--- a/.github/workflows/build_wheel.yml
+++ b/.github/workflows/build_wheel.yml
@@ -22,12 +22,9 @@ jobs:
   build_wheel:
     strategy:
       matrix:
-        python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+        python: ["3.9", "3.10", "3.11", "3.12"]
         cuda: ["11.8", "12.1", "12.4"]
-        torch: ["2.4.1", "2.5.1"]
-        exclude: # torch 2.5.1 dropped support for python 3.8
-          - python: "3.8"
-            torch: "2.5.1"
+        torch: ["2.4.1", "2.5.1", "2.6.0"]
     runs-on: [self-hosted, linux, release]
     env:
       PYTHON_VERSION: ${{ matrix.python }}
diff --git a/.github/workflows/package_test.yml b/.github/workflows/package_test.yml
index 3df38919..2c6a6419 100644
--- a/.github/workflows/package_test.yml
+++ b/.github/workflows/package_test.yml
@@ -40,7 +40,7 @@ jobs:
       matrix:
         python: ["3.12"]
         cuda: ["12.4"]
-        torch: ["2.5.1"]
+        torch: ["2.6.0"]
     runs-on: [self-hosted, linux, build]
     env:
       PYTHON_VERSION: ${{ matrix.python }}
diff --git a/.github/workflows/publish_wheel.yml b/.github/workflows/publish_wheel.yml
index 20b2c4c6..4b76116b 100644
--- a/.github/workflows/publish_wheel.yml
+++ b/.github/workflows/publish_wheel.yml
@@ -23,7 +23,7 @@ jobs:
       matrix:
         python: ["3.9", "3.10", "3.11", "3.12"]
         cuda: ["12.4"]
-        torch: ["2.5.1"]
+        torch: ["2.6.0"]
     runs-on: [self-hosted, linux, release]
     env:
       PYTHON_VERSION: ${{ matrix.python }}
diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml
index aa3bd8cf..2ec32aad 100644
--- a/.github/workflows/release_test.yml
+++ b/.github/workflows/release_test.yml
@@ -21,7 +21,7 @@ jobs:
       matrix:
         python: ["3.9", "3.10", "3.11", "3.12"]
         cuda: ["12.4"]
-        torch: ["2.5.1"]
+        torch: ["2.6.0"]
     runs-on: [self-hosted, linux, release]
     env:
       PYTHON_VERSION: ${{ matrix.python }}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3adabde1..aa28bf84 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -196,25 +196,25 @@ if (DEFINED ENV{LIBTORCH_ROOT})
 else()
   include(FetchContent)
   if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.4)
-    # download libtorch 2.5.1 with cuda 12.4 from pytorch.org
+    # download libtorch 2.6.0 with cuda 12.4 from pytorch.org
     if (USE_CXX11_ABI)
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu124.zip")
+      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu124.zip")
     else()
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.1%2Bcu124.zip")
+      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.6.0%2Bcu124.zip")
     endif()
   elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.1)
-    # download libtorch 2.5.1 with cuda 12.1 from pytorch.org
+    # download libtorch 2.6.0 with cuda 12.1 from pytorch.org
     if (USE_CXX11_ABI)
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu121.zip")
+      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu121.zip")
     else()
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.5.1%2Bcu121.zip")
+      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.6.0%2Bcu121.zip")
     endif()
   elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.8)
-    # download libtorch 2.5.1 with cuda 11.8 from pytorch.org
+    # download libtorch 2.6.0 with cuda 11.8 from pytorch.org
     if (USE_CXX11_ABI)
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu118.zip")
+      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu118.zip")
     else()
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.5.1%2Bcu118.zip")
+      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.6.0%2Bcu118.zip")
     endif()
   else()
     # error out if cuda version is not supported
@@ -234,7 +234,7 @@ else()
   FetchContent_MakeAvailable(libtorch)

   find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
-  message(STATUS "Downloading and using libtorch 2.5.1 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
+  message(STATUS "Downloading and using libtorch 2.6.0 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
 endif()

 # check if USE_CXX11_ABI is set correctly
diff --git a/README.md b/README.md
index bc1488ca..78decea1 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ ScaleLLM is currently undergoing active development. We are fully committed to c
 ScaleLLM is available as a Python Wheel package on PyPI. You can install it using pip:

 ```bash
-# Install scalellm with CUDA 12.4 and Pytorch 2.5.1
+# Install scalellm with CUDA 12.4 and Pytorch 2.6.0
 pip install -U scalellm
 ```
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 612bfec1..3c5e30a7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -12,7 +12,7 @@ ScaleLLM is available as a Python Wheel package on `PyPI