
Commit

Merge branch 'main' into mla
guocuimi authored Feb 8, 2025
2 parents e729ea9 + 532f638 commit 3d31c73
Showing 15 changed files with 114 additions and 137 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/build_wheel.yml
@@ -22,12 +22,9 @@ jobs:
build_wheel:
strategy:
matrix:
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python: ["3.9", "3.10", "3.11", "3.12"]
cuda: ["11.8", "12.1", "12.4"]
torch: ["2.4.1", "2.5.1"]
exclude: # torch 2.5.1 dropped support for python 3.8
- python: "3.8"
torch: "2.5.1"
torch: ["2.4.1", "2.5.1", "2.6.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
2 changes: 1 addition & 1 deletion .github/workflows/package_test.yml
@@ -40,7 +40,7 @@ jobs:
matrix:
python: ["3.12"]
cuda: ["12.4"]
torch: ["2.5.1"]
torch: ["2.6.0"]
runs-on: [self-hosted, linux, build]
env:
PYTHON_VERSION: ${{ matrix.python }}
2 changes: 1 addition & 1 deletion .github/workflows/publish_wheel.yml
@@ -23,7 +23,7 @@ jobs:
matrix:
python: ["3.9", "3.10", "3.11", "3.12"]
cuda: ["12.4"]
torch: ["2.5.1"]
torch: ["2.6.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
2 changes: 1 addition & 1 deletion .github/workflows/release_test.yml
@@ -21,7 +21,7 @@ jobs:
matrix:
python: ["3.9", "3.10", "3.11", "3.12"]
cuda: ["12.4"]
torch: ["2.5.1"]
torch: ["2.6.0"]
runs-on: [self-hosted, linux, release]
env:
PYTHON_VERSION: ${{ matrix.python }}
20 changes: 10 additions & 10 deletions CMakeLists.txt
@@ -196,25 +196,25 @@ if (DEFINED ENV{LIBTORCH_ROOT})
else()
include(FetchContent)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.4)
# download libtorch 2.5.1 with cuda 12.4 from pytorch.org
# download libtorch 2.6.0 with cuda 12.4 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu124.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu124.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.5.1%2Bcu124.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu124/libtorch-shared-with-deps-2.6.0%2Bcu124.zip")
endif()
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.1)
# download libtorch 2.5.1 with cuda 12.1 from pytorch.org
# download libtorch 2.6.0 with cuda 12.1 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu121.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu121.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.5.1%2Bcu121.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.6.0%2Bcu121.zip")
endif()
elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.8)
# download libtorch 2.5.1 with cuda 11.8 from pytorch.org
# download libtorch 2.6.0 with cuda 11.8 from pytorch.org
if (USE_CXX11_ABI)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.5.1%2Bcu118.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu118.zip")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.5.1%2Bcu118.zip")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-2.6.0%2Bcu118.zip")
endif()
else()
# error out if cuda version is not supported
@@ -234,7 +234,7 @@ else()
FetchContent_MakeAvailable(libtorch)

find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
message(STATUS "Downloading and using libtorch 2.5.1 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
message(STATUS "Downloading and using libtorch 2.6.0 for cuda ${CUDA_VERSION} at ${libtorch_SOURCE_DIR}")
endif()

# check if USE_CXX11_ABI is set correctly
2 changes: 1 addition & 1 deletion README.md
@@ -55,7 +55,7 @@ ScaleLLM is currently undergoing active development. We are fully committed to c

ScaleLLM is available as a Python Wheel package on PyPI. You can install it using pip:
```bash
# Install scalellm with CUDA 12.4 and Pytorch 2.5.1
# Install scalellm with CUDA 12.4 and Pytorch 2.6.0
pip install -U scalellm
```

2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -12,7 +12,7 @@ ScaleLLM is available as a Python Wheel package on `PyPI <https://pypi.org/proje

.. code-block:: bash
# Install ScaleLLM with CUDA 12.4 and Pytorch 2.5.1
# Install ScaleLLM with CUDA 12.4 and Pytorch 2.6.0
$ pip install -U scalellm
18 changes: 18 additions & 0 deletions docs/source/quick_start.rst
@@ -25,6 +25,12 @@ If you want to install ScaleLLM with different versions of CUDA and PyTorch, you

.. tabs::

.. tab:: PyTorch 2.6.0

.. code-block:: bash
$ pip install -U scalellm -i https://whl.vectorch.com/cu124/torch2.6.0/
.. tab:: PyTorch 2.5.1

.. code-block:: bash
@@ -41,6 +47,12 @@ If you want to install ScaleLLM with different versions of CUDA and PyTorch, you

.. tabs::

.. tab:: PyTorch 2.6.0

.. code-block:: bash
$ pip install -U scalellm -i https://whl.vectorch.com/cu121/torch2.6.0/
.. tab:: PyTorch 2.5.1

.. code-block:: bash
@@ -57,6 +69,12 @@ If you want to install ScaleLLM with different versions of CUDA and PyTorch, you

.. tabs::

.. tab:: PyTorch 2.6.0

.. code-block:: bash
$ pip install -U scalellm -i https://whl.vectorch.com/cu118/torch2.6.0/
.. tab:: PyTorch 2.5.1

.. code-block:: bash
7 changes: 6 additions & 1 deletion scalellm/downloader.py
@@ -16,7 +16,12 @@ def convert_pickle_to_safetensors(path):
continue

# load the model
model = torch.load(file_path, map_location="cpu")
try:
model = torch.load(file_path, map_location="cpu")
except Exception as e:
print(f"Error loading {filename}: {e}")
continue

if hasattr(model, "state_dict"):
state_dict = model.state_dict()
else:
88 changes: 25 additions & 63 deletions src/kernels/attention/mha_kernel_sm80.cuh
@@ -299,21 +299,6 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
return;
}

// ############### Prologue ###############
int n_block_idx = n_block_max - 1;
// produce query: [] => [q]
produce_query();
cp_async_fence();
// produce key: [q] => [q, k]
produce_key(n_block_idx);
cp_async_fence();

// ############### Mainloop ###############
// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrAccS_rc_view =
make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));

auto apply_logits_soft_cap = [&](auto& tSrAccS) {
if constexpr (SOFT_CAP) {
CUTE_UNROLL
@@ -323,7 +308,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
}
};

constexpr int kMMA_M = size<1>(tSrAccS);
constexpr int kMMA_M = size<1>(tOrAccO);
using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
using Mask = Mask<kBlockM, kBlockM, kRowsPerMMA, kMMA_M, ALIBI, LOCAL>;

@@ -338,12 +323,26 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
sm_scale,
params.alibi_slopes_ptr);

// seperate oob mask iterations for better performance
// ############### Prologue ###############
// produce query: [] => [q]
produce_query();
cp_async_fence();
// produce key: [q] => [q, k]
produce_key(n_block_max - 1);
cp_async_fence();

// ############### Mainloop ###############
constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
const int n_blocks = n_block_max - n_block_min;

// oob mask iterations
CUTE_UNROLL
for (int i = 0; i < n_oob_mask; ++i) {
CUTE_NO_UNROLL
for (int i = 0; i < n_blocks; ++i) {
const int n_block_idx = n_block_max - 1 - i;

// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrAccS_rc_view =
make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
clear(tSrAccS);

// wait key, queue: [q, k] => []
@@ -361,57 +360,20 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
// 1> S = Q@K.T
compute_qk(tSrAccS);

if constexpr (SOFT_CAP) {
apply_logits_soft_cap(tSrAccS);
}
mask.apply(tSrAccS_rc_view, n_block_idx);
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// wait value, [v] => []
cp_async_wait<0>();
__syncthreads();

// produce next key: [] => [k]
if (n_block_idx > n_block_min) {
produce_key_no_oob(n_block_idx - 1);
}
cp_async_fence();

// 2> O = softmax(S)*V
compute_sv(tSrAccS, tOrAccO);

--n_block_idx;
if (n_block_idx < n_block_min) {
// no more kv blocks to process
break;
}
}

// non-oob mask iterations
CUTE_NO_UNROLL
for (; n_block_idx >= n_block_min; --n_block_idx) {
clear(tSrAccS);

// wait key, queue: [q, k] => []
cp_async_wait<0>();
__syncthreads();

// produce value, [] => [v]
produce_value_no_oob(n_block_idx);
cp_async_fence();

// 1> S = Q@K.T
compute_qk(tSrAccS);

if constexpr (SOFT_CAP) {
apply_logits_soft_cap(tSrAccS);
}
mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// wait value, [v] => []
cp_async_wait<0>();
__syncthreads();
if (i < n_oob_mask) {
mask.apply(tSrAccS_rc_view, n_block_idx);
} else {
mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
}
softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

// produce next key: [] => [k]
if (n_block_idx > n_block_min) {
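Note on the hunk above: the previous kernel ran two loops over KV blocks, a short unrolled one that always applied the out-of-bounds (OOB) mask followed by a `CUTE_NO_UNROLL` loop that skipped it; the new version uses a single loop over all blocks and takes the masked path only while `i < n_oob_mask`. The following is a minimal host-side C++ sketch of that control-flow change only, not the actual CUTE kernel; `process_block`, the tile sizes, and the block range are illustrative placeholders.

```cpp
// Host-side sketch of the merged mainloop: one loop over KV blocks (processed in
// reverse, as in the kernel), with the OOB-masked path taken only on the first
// n_oob_mask iterations. Placeholder names; not the real kernel code.
#include <cstdio>

namespace {

constexpr int kBlockM = 64;  // query tile rows (illustrative)
constexpr int kBlockN = 64;  // key/value tile rows (illustrative)

void process_block(int n_block_idx, bool oob_mask) {
  // Stand-in for: clear accumulator, compute_qk, mask.apply<OOB_MASK>,
  // softmax.rescale, compute_sv.
  std::printf("block %d, oob_mask=%s\n", n_block_idx, oob_mask ? "true" : "false");
}

}  // namespace

int main() {
  const int n_block_min = 0;  // illustrative block range
  const int n_block_max = 8;

  // Roughly the number of KV tiles that can straddle the mask boundary for one
  // query tile; only those iterations need the more expensive masked path.
  constexpr int n_oob_mask = (kBlockM + kBlockN - 1) / kBlockN + 1;  // ceil_div + 1
  const int n_blocks = n_block_max - n_block_min;

  for (int i = 0; i < n_blocks; ++i) {
    const int n_block_idx = n_block_max - 1 - i;
    process_block(n_block_idx, /*oob_mask=*/i < n_oob_mask);
  }
  return 0;
}
```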
16 changes: 2 additions & 14 deletions src/kernels/attention/mha_sm80_bench.cu
@@ -7,22 +7,10 @@
#include "mha_dispatch_sm80.cuh"
#include "mha_kernel_sm80.cuh" // IWYU pragma: keep
#include "mha_params.h"
#include "static_dispatch.h"

using namespace llm;

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

void mha_bench_sm80(nvbench::state& state) {
// Collect CUPTI metrics
state.collect_cupti_metrics();
@@ -82,7 +70,7 @@ void mha_bench_sm80(nvbench::state& state) {
params.sliding_window = sliding_window;

state.exec([&](nvbench::launch& launch) {
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
});
});
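Note on the hunk above: the local `DISPATCH_HEAD_DIM_` macro is deleted and the call site switches to `DISPATCH_HEAD_DIM`, which presumably comes from the included `static_dispatch.h`. Assuming the shared macro has the same shape as the one removed here, a self-contained C++ sketch of the dispatch pattern and its usage would look roughly like this (`run_kernel` and the head-dim value are illustrative stand-ins, not the real benchmark entry point):

```cpp
// Sketch of compile-time head-dim dispatch: a runtime head_dim is rounded up to a
// supported size and forwarded as a constexpr template argument via a lambda.
#include <cassert>
#include <cstdio>

#define DISPATCH_HEAD_DIM(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
  [&] {                                                    \
    if (HEAD_DIM_V <= 64) {                                \
      constexpr static int HEAD_DIM_NAME = 64;             \
      return __VA_ARGS__();                                \
    } else if (HEAD_DIM_V <= 128) {                        \
      constexpr static int HEAD_DIM_NAME = 128;            \
      return __VA_ARGS__();                                \
    } else {                                               \
      assert(false);                                       \
    }                                                      \
  }()

template <int HEAD_DIM>
void run_kernel() {
  // Stand-in for run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, stream).
  std::printf("launch kernel with HEAD_DIM=%d\n", HEAD_DIM);
}

int main() {
  int head_dim = 96;  // runtime value; dispatches to the 128 specialization
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] { run_kernel<HEAD_DIM>(); });
  return 0;
}
```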
16 changes: 2 additions & 14 deletions src/kernels/attention/mha_sm80_pagedkv_bench.cu
@@ -8,22 +8,10 @@
#include "mha_dispatch_sm80.cuh"
#include "mha_kernel_sm80.cuh" // IWYU pragma: keep
#include "mha_params.h"
#include "static_dispatch.h"

using namespace llm;

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

void mha_bench_sm80(nvbench::state& state) {
// Collect CUPTI metrics
state.collect_cupti_metrics();
@@ -130,7 +118,7 @@ void mha_bench_sm80(nvbench::state& state) {
params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();

state.exec([&](nvbench::launch& launch) {
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
});
});
2 changes: 1 addition & 1 deletion src/kernels/attention/mha_traits_sm80.h
@@ -92,7 +92,7 @@ struct MHATraitsSM80 {
// Tiled copy for QKV
// g2s tiled copy for q
using GmemTiledCopyQ = decltype(make_tiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>, DType>{},
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>{},
GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4)
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
));
