From 3460e9dbc7e9980280db990294180e5d5581d539 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 9 Sep 2024 22:26:04 -0700 Subject: [PATCH] feat: added pass-in alibi slopes support for flash infer kernel (#334) --- .../attention/flash_infer/attention_kernel.h | 22 ++++++++-------- .../flash_infer/attention_wrapper.cu | 25 ++++++++++++------- .../attention/flash_infer/attention_wrapper.h | 4 +-- .../flash_infer/generate_instantiations.py | 4 +-- tests/kernels/attention/flash_infer_test.py | 20 +++++++++------ tests/kernels/attention/ref_attention.py | 19 ++++++++++++-- 6 files changed, 60 insertions(+), 34 deletions(-) diff --git a/src/kernels/attention/flash_infer/attention_kernel.h b/src/kernels/attention/flash_infer/attention_kernel.h index 0a18c739..22b1f399 100644 --- a/src/kernels/attention/flash_infer/attention_kernel.h +++ b/src/kernels/attention/flash_infer/attention_kernel.h @@ -871,7 +871,6 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( IdType* __restrict__ kv_indptr, uint8_t* __restrict__ custom_mask, IdType* __restrict__ qk_indptr, - IdType* __restrict__ q_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, float* __restrict__ lse, @@ -881,7 +880,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( const uint_fastdiv group_size, int32_t maybe_window_left, const float logits_soft_cap, - float sm_scale) { + float sm_scale, + float* __restrict__ alibi_slopes) { static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= (logits_post_hook == LogitsPostHook::kNone @@ -898,7 +898,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( } const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; - float alibi_slopes[num_frags_x][2]; + float alibi_slopes_frag[num_frags_x][2]; const uint32_t request_idx = request_indices[bx], qo_tile_idx = q_tile_indices[bx], @@ -994,8 +994,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( const uint32_t qo_head_idx = kv_head_idx * group_size + (qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size; - alibi_slopes[fx][j] = - get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + alibi_slopes_frag[fx][j] = alibi_slopes[qo_head_idx] * math::log2e; } } } @@ -1126,7 +1125,6 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( logits_soft_cap); if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - // TODO(Zihao): handle the case that q_offset is specified apply_alibi_bias( qo_packed_idx_base, chunk_start + (iter * num_warps_z + @@ -1134,7 +1132,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( num_frags_z * 16, int(kv_len) - int(qo_len), group_size, - alibi_slopes, + alibi_slopes_frag, s_frag); } // apply mask @@ -1277,7 +1275,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, IdType* kv_tile_indices, IdType* q_indptr, IdType* kv_indptr, - IdType* q_offset, paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, @@ -1295,6 +1292,7 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, int32_t window_left, float logits_soft_cap, float sm_scale, + float* alibi_slopes, cudaStream_t stream) { #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same_v) { @@ -1402,7 +1400,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, (void*)&kv_indptr, (void*)&custom_mask, (void*)&qk_indptr, - (void*)&q_offset, (void*)&o_indptr, (void*)&o, (void*)&lse, @@ -1412,7 +1409,8 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, 
(void*)&group_size_fastdiv, (void*)&window_left, (void*)&logits_soft_cap, - (void*)&sm_scale}; + (void*)&sm_scale, + (void*)&alibi_slopes}; FLASHINFER_CUDA_CALL(cudaLaunchKernel( (void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { @@ -1426,7 +1424,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, (void*)&kv_indptr, (void*)&custom_mask, (void*)&qk_indptr, - (void*)&q_offset, (void*)&o_indptr, (void*)&tmp_v, (void*)&tmp_s, @@ -1436,7 +1433,8 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, (void*)&group_size_fastdiv, (void*)&window_left, (void*)&logits_soft_cap, - (void*)&sm_scale}; + (void*)&sm_scale, + (void*)&alibi_slopes}; FLASHINFER_CUDA_CALL(cudaLaunchKernel( (void*)kernel, nblks, nthrs, args, smem_size, stream)); FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, diff --git a/src/kernels/attention/flash_infer/attention_wrapper.cu b/src/kernels/attention/flash_infer/attention_wrapper.cu index 32fd861f..67ccfd11 100644 --- a/src/kernels/attention/flash_infer/attention_wrapper.cu +++ b/src/kernels/attention/flash_infer/attention_wrapper.cu @@ -29,7 +29,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, IdType* kv_tile_indices, IdType* q_indptr, IdType* kv_indptr, - IdType* q_offset, paged_kv_t paged_kv, uint8_t* custom_mask, IdType* qk_indptr, @@ -47,6 +46,7 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, int32_t window_left, float logits_soft_cap, float sm_scale, + float* alibi_slopes, cudaStream_t stream); template paged_kv, uint8_t* custom_mask, IdType* qk_indptr, @@ -71,6 +70,7 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler, int32_t window_left, float logits_soft_cap, float sm_scale, + float* alibi_slopes, cudaStream_t stream) { DTypeOut* tmp_v = nullptr; float* tmp_s = nullptr; @@ -110,7 +110,6 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler, kv_tile_indices, q_indptr, kv_indptr, - q_offset, paged_kv, custom_mask, qk_indptr, @@ -128,6 +127,7 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler, window_left, logits_soft_cap, sm_scale, + alibi_slopes, stream); }); return cudaSuccess; @@ -200,10 +200,10 @@ torch::Tensor BatchPrefillWrapper::Run( std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - unsigned int pos_encoding_mode, int window_left, float logits_soft_cap, - float sm_scale) { + float sm_scale, + std::optional alibi_slopes) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(kv_indptr); @@ -226,8 +226,8 @@ torch::Tensor BatchPrefillWrapper::Run( CHECK_DIM(4, paged_k_cache.value()); CHECK_DIM(4, paged_v_cache.value()); - CHECK_DIM(1, paged_kv_indptr); // (B + 1,) - CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) + CHECK_DIM(1, paged_kv_indptr); // (B + 1,) + CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); @@ -255,6 +255,9 @@ torch::Tensor BatchPrefillWrapper::Run( TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; + const auto pos_encoding_mode = alibi_slopes.has_value() + ? 
PosEncodingMode::kALiBi + : PosEncodingMode::kNone; auto q_scalar_type = q.scalar_type(); auto kv_scalar_type = paged_k_cache->scalar_type(); @@ -288,7 +291,6 @@ torch::Tensor BatchPrefillWrapper::Run( static_cast(q.data_ptr()), qo_indptr.data_ptr(), kv_indptr.data_ptr(), - /*q_offset=*/nullptr, paged_kv, /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, @@ -298,6 +300,9 @@ torch::Tensor BatchPrefillWrapper::Run( window_left, logits_soft_cap, sm_scale, + alibi_slopes.has_value() + ? alibi_slopes->data_ptr() + : nullptr, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with " @@ -342,7 +347,6 @@ torch::Tensor BatchPrefillWrapper::Run( static_cast(q.data_ptr()), qo_indptr.data_ptr(), kv_indptr.data_ptr(), - /*q_offset=*/nullptr, paged_kv, /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, @@ -352,6 +356,9 @@ torch::Tensor BatchPrefillWrapper::Run( window_left, logits_soft_cap, sm_scale, + alibi_slopes.has_value() + ? alibi_slopes->data_ptr() + : nullptr, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed " diff --git a/src/kernels/attention/flash_infer/attention_wrapper.h b/src/kernels/attention/flash_infer/attention_wrapper.h index b3c87812..611b6fbc 100644 --- a/src/kernels/attention/flash_infer/attention_wrapper.h +++ b/src/kernels/attention/flash_infer/attention_wrapper.h @@ -37,10 +37,10 @@ class BatchPrefillWrapper { std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - unsigned int pos_encoding_mode, int window_left, float logits_soft_cap, - float sm_scale); + float sm_scale, + std::optional alibi_slopes); private: std::shared_ptr handler_; diff --git a/src/kernels/attention/flash_infer/generate_instantiations.py b/src/kernels/attention/flash_infer/generate_instantiations.py index 05dab2f5..310bcc82 100755 --- a/src/kernels/attention/flash_infer/generate_instantiations.py +++ b/src/kernels/attention/flash_infer/generate_instantiations.py @@ -58,12 +58,12 @@ {DType}, {IDType}>( {QDType}* q, {IDType}* request_indices, {IDType}* q_tile_indices, {IDType}* kv_tile_indices, - {IDType}* q_indptr, {IDType}* kv_indptr, {IDType}* q_offset, + {IDType}* q_indptr, {IDType}* kv_indptr, paged_kv_t<{KVDType}, {IDType}> paged_kv, uint8_t* custom_mask, {IDType}* qk_indptr, {IDType}* o_indptr, {DType}* o, {DType}* tmp_v, float* tmp_s, float* lse, {IDType}* merge_indptr, bool* block_valid_mask, {IDType}* kv_chunk_size_ptr, uint32_t max_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, int32_t window_left, - float logits_soft_cap, float sm_scale, cudaStream_t stream); + float logits_soft_cap, float sm_scale, float* alibi_slopes, cudaStream_t stream); """ FILE_TEMPLATE = """#include "attention_kernel.h" diff --git a/tests/kernels/attention/flash_infer_test.py b/tests/kernels/attention/flash_infer_test.py index 5e78ad70..9285d68a 100644 --- a/tests/kernels/attention/flash_infer_test.py +++ b/tests/kernels/attention/flash_infer_test.py @@ -11,11 +11,11 @@ @pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)]) @pytest.mark.parametrize("head_size", [64, 128, 256]) @pytest.mark.parametrize("n_blocks", [100]) -@pytest.mark.parametrize("block_size", [4, 8, 16, 32]) +@pytest.mark.parametrize("block_size", [4, 8, 16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0, 50.0]) -@pytest.mark.parametrize("sliding_window", [-1, 40]) -# @pytest.mark.parametrize("alibi", [False, 
True]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 50.0]) +@pytest.mark.parametrize("sliding_window", [-1, 50]) +@pytest.mark.parametrize("alibi", [False, True]) @torch.inference_mode def test_flashinfer_varlen_masked_self_attention( seq_lens: List[Tuple[int, int]], @@ -26,6 +26,7 @@ def test_flashinfer_varlen_masked_self_attention( block_size: int, logits_soft_cap: float, sliding_window: int, + alibi: bool, ) -> None: torch.set_default_device("cuda") @@ -89,7 +90,8 @@ def test_flashinfer_varlen_masked_self_attention( empty_q_data, ) - pos_encoding_mode = 0 + alibi_slopes = torch.randn(n_heads, dtype=torch.float32) if alibi else None + output = wrapper.run( query, qo_indptr, @@ -98,10 +100,10 @@ def test_flashinfer_varlen_masked_self_attention( value_cache, paged_kv_indptr, paged_kv_indices, - pos_encoding_mode, sliding_window, logits_soft_cap, sm_scale, + alibi_slopes, ) ref_output = varlen_masked_self_attention( @@ -114,9 +116,13 @@ def test_flashinfer_varlen_masked_self_attention( sm_scale=sm_scale, logits_soft_cap=logits_soft_cap, sliding_window=sliding_window, + alibi_slopes=alibi_slopes, ) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) + if alibi and dtype == torch.bfloat16: + torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-2) + else: + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) if __name__ == "__main__": diff --git a/tests/kernels/attention/ref_attention.py b/tests/kernels/attention/ref_attention.py index f7e6282f..2dd1c829 100644 --- a/tests/kernels/attention/ref_attention.py +++ b/tests/kernels/attention/ref_attention.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import torch @@ -7,6 +7,7 @@ def masked_self_attention( query: torch.Tensor, # [q_len, n_heads, head_dim] key: torch.Tensor, # [kv_len, n_heads, head_dim] value: torch.Tensor, # [kv_len, n_heads, head_dim] + alibi_bias: Optional[torch.Tensor], # [n_heads, 1, kv_len] mask: torch.Tensor, # [n_heads, q_len, kv_len] sm_scale: float, logits_soft_cap: float, @@ -20,6 +21,10 @@ def masked_self_attention( # apply soft_cap if logits_soft_cap > 0.0: scores = torch.tanh(scores / logits_soft_cap) * logits_soft_cap + + # apply alibi bias + if alibi_bias is not None: + scores += alibi_bias # apply mask scores.masked_fill_(mask == 0, float("-inf")) @@ -40,6 +45,7 @@ def varlen_masked_self_attention( sm_scale: float, logits_soft_cap: float = 0.0, sliding_window: int = -1, + alibi_slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert key_cache.shape == value_cache.shape @@ -84,12 +90,21 @@ def varlen_masked_self_attention( # returns the lower triangular part of a matrix mask = mask.tril(diagonal=kv_len - q_len).to(query) - # TODO: add alibi bias support + # calculate alibi attention bias + alibi_bias = None + if alibi_slopes is not None: + assert alibi_slopes.shape == (n_heads,) + # since it's causal mask, we can just use [0, 1, ...,, kv_len) + distance = torch.arange(kv_len, dtype=torch.float32) + # [n_heads, 1, kv_len] + alibi_bias = distance.view(1, 1, -1) * alibi_slopes.view(n_heads, 1, 1) + out = masked_self_attention( query=q, key=k, value=v, + alibi_bias=alibi_bias, mask=mask, sm_scale=sm_scale, logits_soft_cap=logits_soft_cap,
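
Note: with this change the kernel no longer derives ALiBi slopes from the head index (the in-kernel get_alibi_slope call is removed), so callers must pass a float32 slope tensor of shape [num_qo_heads] whenever ALiBi is wanted; passing std::nullopt keeps PosEncodingMode::kNone. Below is a minimal sketch of building the standard slopes from the ALiBi paper; default_alibi_slopes is a hypothetical helper, it assumes a power-of-two head count (the paper interpolates extra slopes otherwise), and it is not guaranteed to match the previously in-kernel get_alibi_slope for other head counts.

import torch

def default_alibi_slopes(n_heads: int) -> torch.Tensor:
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8) from the ALiBi paper.
    # Assumes n_heads is a power of two; other counts need the paper's
    # interpolation scheme (not shown here).
    exponents = torch.arange(1, n_heads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * exponents / n_heads)

# Usage mirrors the updated test: alibi_slopes is now the last argument of run()
# and pos_encoding_mode is no longer passed.
# wrapper.run(query, qo_indptr, kv_indptr, key_cache, value_cache,
#             paged_kv_indptr, paged_kv_indices, sliding_window,
#             logits_soft_cap, sm_scale, default_alibi_slopes(n_heads))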
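
Note: the reference implementation in ref_attention.py builds the bias as slope * [0, 1, ..., kv_len) rather than the per-row relative position slope * (j - i - (kv_len - q_len)). The two differ only by a constant per query row, which softmax cancels, so the simpler form is valid under a causal mask. A small self-contained check of that equivalence (names here are illustrative only):

import torch

torch.manual_seed(0)
q_len, kv_len, slope = 4, 6, 0.5
scores = torch.randn(q_len, kv_len)

# causal mask: query i may attend to keys j <= i + (kv_len - q_len)
mask = torch.ones(q_len, kv_len).tril(diagonal=kv_len - q_len)

# per-row relative-position bias: slope * (j - (i + kv_len - q_len))
i = torch.arange(q_len).view(-1, 1) + (kv_len - q_len)
j = torch.arange(kv_len).view(1, -1)
bias_rel = slope * (j - i).float()

# bias used by the reference: slope * j; differs from bias_rel only by a
# per-row constant, which softmax ignores
bias_abs = slope * j.float().expand(q_len, kv_len)

def attn_weights(bias):
    s = (scores + bias).masked_fill(mask == 0, float("-inf"))
    return s.softmax(dim=-1)

assert torch.allclose(attn_weights(bias_rel), attn_weights(bias_abs), atol=1e-6)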