feat: added pass-in alibi slopes support for flash infer kernel (vect…
guocuimi authored Sep 10, 2024
1 parent 9131d8a commit 3460e9d
Showing 6 changed files with 60 additions and 34 deletions.
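Note: with this change the kernel no longer derives ALiBi slopes internally via get_alibi_slope; the caller passes one slope per query head. A minimal sketch of how such slopes are conventionally constructed (geometric sequence from the ALiBi paper; assumes the head count is a power of two — this helper is illustrative, not part of the commit):

import torch

def make_alibi_slopes(num_qo_heads: int) -> torch.Tensor:
    # Geometric sequence 2^(-8/n), 2^(-16/n), ... as in the ALiBi paper;
    # assumes num_qo_heads is a power of two for simplicity.
    ratio = 2.0 ** (-8.0 / num_qo_heads)
    return torch.tensor([ratio ** (i + 1) for i in range(num_qo_heads)],
                        dtype=torch.float32)

print(make_alibi_slopes(8))  # tensor([0.5000, 0.2500, ..., 0.0039])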
22 changes: 10 additions & 12 deletions src/kernels/attention/flash_infer/attention_kernel.h
@@ -871,7 +871,6 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel(
IdType* __restrict__ kv_indptr,
uint8_t* __restrict__ custom_mask,
IdType* __restrict__ qk_indptr,
IdType* __restrict__ q_offset,
IdType* __restrict__ o_indptr,
DTypeOut* __restrict__ o,
float* __restrict__ lse,
@@ -881,7 +880,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel(
const uint_fastdiv group_size,
int32_t maybe_window_left,
const float logits_soft_cap,
float sm_scale) {
float sm_scale,
float* __restrict__ alibi_slopes) {
static_assert(sizeof(DTypeQ) == 2);
static_assert(sizeof(DTypeOut) == 2);
sm_scale *= (logits_post_hook == LogitsPostHook::kNone
@@ -898,7 +898,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel(
}
const uint32_t num_kv_heads = gridDim.z,
num_qo_heads = num_kv_heads * group_size;
float alibi_slopes[num_frags_x][2];
float alibi_slopes_frag[num_frags_x][2];

const uint32_t request_idx = request_indices[bx],
qo_tile_idx = q_tile_indices[bx],
@@ -994,8 +994,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel(
const uint32_t qo_head_idx =
kv_head_idx * group_size +
(qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size;
alibi_slopes[fx][j] =
get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
alibi_slopes_frag[fx][j] = alibi_slopes[qo_head_idx] * math::log2e;
}
}
}
@@ -1126,15 +1125,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel(
logits_soft_cap);

if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
// TODO(Zihao): handle the case that q_offset is specified
apply_alibi_bias<num_frags_x, num_frags_z>(
qo_packed_idx_base,
chunk_start + (iter * num_warps_z +
get_warp_idx_z<num_warps_x, num_warps_z>()) *
num_frags_z * 16,
int(kv_len) - int(qo_len),
group_size,
alibi_slopes,
alibi_slopes_frag,
s_frag);
}
// apply mask
@@ -1277,7 +1275,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
IdType* kv_tile_indices,
IdType* q_indptr,
IdType* kv_indptr,
IdType* q_offset,
paged_kv_t<DTypeKV, IdType> paged_kv,
uint8_t* custom_mask,
IdType* qk_indptr,
@@ -1295,6 +1292,7 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
int32_t window_left,
float logits_soft_cap,
float sm_scale,
float* alibi_slopes,
cudaStream_t stream) {
#if (__CUDA_ARCH__ < 800)
if constexpr (std::is_same_v<DTypeQ, nv_bfloat16>) {
@@ -1402,7 +1400,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
(void*)&kv_indptr,
(void*)&custom_mask,
(void*)&qk_indptr,
(void*)&q_offset,
(void*)&o_indptr,
(void*)&o,
(void*)&lse,
@@ -1412,7 +1409,8 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
(void*)&group_size_fastdiv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale};
(void*)&sm_scale,
(void*)&alibi_slopes};
FLASHINFER_CUDA_CALL(cudaLaunchKernel(
(void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
@@ -1426,7 +1424,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
(void*)&kv_indptr,
(void*)&custom_mask,
(void*)&qk_indptr,
(void*)&q_offset,
(void*)&o_indptr,
(void*)&tmp_v,
(void*)&tmp_s,
@@ -1436,7 +1433,8 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
(void*)&group_size_fastdiv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale};
(void*)&sm_scale,
(void*)&alibi_slopes};
FLASHINFER_CUDA_CALL(cudaLaunchKernel(
(void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v,
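Note on the math::log2e factor visible above: the slope is pre-scaled by log2(e), which is consistent with a kernel that evaluates the softmax with exp2 rather than exp (the same scaling is applied to sm_scale). A tiny numeric check of that equivalence, with made-up numbers (illustrative only, not part of the commit):

import math

slope, distance = 0.25, 7.0                       # hypothetical head slope / key offset
bias_nat = slope * distance                       # bias in the e**x domain
bias_log2 = slope * math.log2(math.e) * distance  # pre-scaled slope, 2**x domain
assert math.isclose(math.exp(bias_nat), 2.0 ** bias_log2)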
25 changes: 16 additions & 9 deletions src/kernels/attention/flash_infer/attention_wrapper.cu
@@ -29,7 +29,6 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
IdType* kv_tile_indices,
IdType* q_indptr,
IdType* kv_indptr,
IdType* q_offset,
paged_kv_t<DTypeKV, IdType> paged_kv,
uint8_t* custom_mask,
IdType* qk_indptr,
@@ -47,6 +46,7 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
int32_t window_left,
float logits_soft_cap,
float sm_scale,
float* alibi_slopes,
cudaStream_t stream);

template <uint32_t HEAD_DIM,
@@ -61,7 +61,6 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler,
DTypeQ* q,
IdType* q_indptr,
IdType* kv_indptr,
IdType* q_offset,
paged_kv_t<DTypeKV, IdType> paged_kv,
uint8_t* custom_mask,
IdType* qk_indptr,
@@ -71,6 +70,7 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler,
int32_t window_left,
float logits_soft_cap,
float sm_scale,
float* alibi_slopes,
cudaStream_t stream) {
DTypeOut* tmp_v = nullptr;
float* tmp_s = nullptr;
@@ -110,7 +110,6 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler,
kv_tile_indices,
q_indptr,
kv_indptr,
q_offset,
paged_kv,
custom_mask,
qk_indptr,
@@ -128,6 +127,7 @@ cudaError_t mha_varlen_wrapper_dispatch(BatchPrefillHandler* handler,
window_left,
logits_soft_cap,
sm_scale,
alibi_slopes,
stream);
});
return cudaSuccess;
@@ -200,10 +200,10 @@ torch::Tensor BatchPrefillWrapper::Run(
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices,
unsigned int pos_encoding_mode,
int window_left,
float logits_soft_cap,
float sm_scale) {
float sm_scale,
std::optional<torch::Tensor> alibi_slopes) {
CHECK_INPUT(q);
CHECK_INPUT(qo_indptr);
CHECK_INPUT(kv_indptr);
@@ -226,8 +226,8 @@
CHECK_DIM(4, paged_k_cache.value());
CHECK_DIM(4, paged_v_cache.value());

CHECK_DIM(1, paged_kv_indptr); // (B + 1,)
CHECK_DIM(1, paged_kv_indices); // (nnz_kv,)
CHECK_DIM(1, paged_kv_indptr); // (B + 1,)
CHECK_DIM(1, paged_kv_indices); // (nnz_kv,)
int64_t batch_size = qo_indptr.size(0) - 1;
int64_t nnz_qo = q.size(0);
int64_t num_qo_heads = q.size(1);
@@ -255,6 +255,9 @@
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
const LogitsPostHook logits_post_hook =
logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;
const auto pos_encoding_mode = alibi_slopes.has_value()
? PosEncodingMode::kALiBi
: PosEncodingMode::kNone;

auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_k_cache->scalar_type();
@@ -288,7 +291,6 @@ torch::Tensor BatchPrefillWrapper::Run(
static_cast<c_type*>(q.data_ptr()),
qo_indptr.data_ptr<int32_t>(),
kv_indptr.data_ptr<int32_t>(),
/*q_offset=*/nullptr,
paged_kv,
/*custom_mask=*/nullptr,
/*qk_indptr=*/nullptr,
@@ -298,6 +300,9 @@
window_left,
logits_soft_cap,
sm_scale,
alibi_slopes.has_value()
? alibi_slopes->data_ptr<float>()
: nullptr,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCache failed with "
@@ -342,7 +347,6 @@ torch::Tensor BatchPrefillWrapper::Run(
static_cast<q_type*>(q.data_ptr()),
qo_indptr.data_ptr<int32_t>(),
kv_indptr.data_ptr<int32_t>(),
/*q_offset=*/nullptr,
paged_kv,
/*custom_mask=*/nullptr,
/*qk_indptr=*/nullptr,
@@ -352,6 +356,9 @@
window_left,
logits_soft_cap,
sm_scale,
alibi_slopes.has_value()
? alibi_slopes->data_ptr<float>()
: nullptr,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithPagedKVCache failed "
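For callers, pos_encoding_mode is no longer an argument of Run: the wrapper picks kALiBi exactly when an alibi_slopes tensor is supplied and kNone otherwise, and forwards the tensor's float data pointer (or nullptr) down to the kernel. A small Python-side sketch of that contract — the validation shown here is an assumption a caller might add, not a check present in the diff:

from typing import Optional
import torch

def resolve_pos_encoding(alibi_slopes: Optional[torch.Tensor], num_qo_heads: int):
    if alibi_slopes is None:
        return "kNone", None
    # hypothetical caller-side sanity checks: one float32 slope per query head
    assert alibi_slopes.dtype == torch.float32
    assert alibi_slopes.numel() == num_qo_heads
    return "kALiBi", alibi_slopes

print(resolve_pos_encoding(None, 8)[0])                                 # kNone
print(resolve_pos_encoding(torch.rand(8, dtype=torch.float32), 8)[0])   # kALiBi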
4 changes: 2 additions & 2 deletions src/kernels/attention/flash_infer/attention_wrapper.h
@@ -37,10 +37,10 @@ class BatchPrefillWrapper {
std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices,
unsigned int pos_encoding_mode,
int window_left,
float logits_soft_cap,
float sm_scale);
float sm_scale,
std::optional<torch::Tensor> alibi_slopes);

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
4 changes: 2 additions & 2 deletions src/kernels/attention/flash_infer/generate_instantiations.py
@@ -58,12 +58,12 @@
{DType},
{IDType}>(
{QDType}* q, {IDType}* request_indices, {IDType}* q_tile_indices, {IDType}* kv_tile_indices,
{IDType}* q_indptr, {IDType}* kv_indptr, {IDType}* q_offset,
{IDType}* q_indptr, {IDType}* kv_indptr,
paged_kv_t<{KVDType}, {IDType}> paged_kv, uint8_t* custom_mask,
{IDType}* qk_indptr, {IDType}* o_indptr, {DType}* o, {DType}* tmp_v, float* tmp_s, float* lse,
{IDType}* merge_indptr, bool* block_valid_mask, {IDType}* kv_chunk_size_ptr, uint32_t max_num_rows,
uint32_t num_qo_heads, uint32_t padded_batch_size, int32_t window_left,
float logits_soft_cap, float sm_scale, cudaStream_t stream);
float logits_soft_cap, float sm_scale, float* alibi_slopes, cudaStream_t stream);
"""

FILE_TEMPLATE = """#include "attention_kernel.h"
20 changes: 13 additions & 7 deletions tests/kernels/attention/flash_infer_test.py
@@ -11,11 +11,11 @@
@pytest.mark.parametrize("num_heads", [(8, 8), (8, 4), (8, 2), (8, 1)])
@pytest.mark.parametrize("head_size", [64, 128, 256])
@pytest.mark.parametrize("n_blocks", [100])
@pytest.mark.parametrize("block_size", [4, 8, 16, 32])
@pytest.mark.parametrize("block_size", [4, 8, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [-1, 40])
# @pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 50.0])
@pytest.mark.parametrize("sliding_window", [-1, 50])
@pytest.mark.parametrize("alibi", [False, True])
@torch.inference_mode
def test_flashinfer_varlen_masked_self_attention(
seq_lens: List[Tuple[int, int]],
@@ -26,6 +26,7 @@ def test_flashinfer_varlen_masked_self_attention(
block_size: int,
logits_soft_cap: float,
sliding_window: int,
alibi: bool,
) -> None:
torch.set_default_device("cuda")

Expand Down Expand Up @@ -89,7 +90,8 @@ def test_flashinfer_varlen_masked_self_attention(
empty_q_data,
)

pos_encoding_mode = 0
alibi_slopes = torch.randn(n_heads, dtype=torch.float32) if alibi else None

output = wrapper.run(
query,
qo_indptr,
@@ -98,10 +100,10 @@ def test_flashinfer_varlen_masked_self_attention(
value_cache,
paged_kv_indptr,
paged_kv_indices,
pos_encoding_mode,
sliding_window,
logits_soft_cap,
sm_scale,
alibi_slopes,
)

ref_output = varlen_masked_self_attention(
Expand All @@ -114,9 +116,13 @@ def test_flashinfer_varlen_masked_self_attention(
sm_scale=sm_scale,
logits_soft_cap=logits_soft_cap,
sliding_window=sliding_window,
alibi_slopes=alibi_slopes,
)

torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
if alibi and dtype == torch.bfloat16:
torch.testing.assert_close(output, ref_output, atol=1e-1, rtol=1e-2)
else:
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
19 changes: 17 additions & 2 deletions tests/kernels/attention/ref_attention.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

import torch

@@ -7,6 +7,7 @@ def masked_self_attention(
query: torch.Tensor, # [q_len, n_heads, head_dim]
key: torch.Tensor, # [kv_len, n_heads, head_dim]
value: torch.Tensor, # [kv_len, n_heads, head_dim]
alibi_bias: Optional[torch.Tensor], # [n_heads, 1, kv_len]
mask: torch.Tensor, # [n_heads, q_len, kv_len]
sm_scale: float,
logits_soft_cap: float,
@@ -20,6 +21,10 @@ def masked_self_attention(
# apply soft_cap
if logits_soft_cap > 0.0:
scores = torch.tanh(scores / logits_soft_cap) * logits_soft_cap

# apply alibi bias
if alibi_bias is not None:
scores += alibi_bias

# apply mask
scores.masked_fill_(mask == 0, float("-inf"))
@@ -40,6 +45,7 @@ def varlen_masked_self_attention(
sm_scale: float,
logits_soft_cap: float = 0.0,
sliding_window: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert key_cache.shape == value_cache.shape

@@ -84,12 +90,21 @@ def varlen_masked_self_attention(
# returns the lower triangular part of a matrix
mask = mask.tril(diagonal=kv_len - q_len).to(query)

# TODO: add alibi bias support
# calculate alibi attention bias
alibi_bias = None
if alibi_slopes is not None:
assert alibi_slopes.shape == (n_heads,)
# since the mask is causal, we can just use distances [0, 1, ..., kv_len)
distance = torch.arange(kv_len, dtype=torch.float32)
# [n_heads, 1, kv_len]
alibi_bias = distance.view(1, 1, -1) * alibi_slopes.view(n_heads, 1, 1)


out = masked_self_attention(
query=q,
key=k,
value=v,
alibi_bias=alibi_bias,
mask=mask,
sm_scale=sm_scale,
logits_soft_cap=logits_soft_cap,
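To make the broadcasting in the reference path above concrete, here is a small worked example of the bias shapes (head count, lengths, and slope values are arbitrary):

import torch

n_heads, q_len, kv_len = 2, 3, 5
alibi_slopes = torch.tensor([0.5, 0.25])               # one slope per head
distance = torch.arange(kv_len, dtype=torch.float32)   # [0, 1, 2, 3, 4]

# [n_heads, 1, kv_len]; broadcasts over the q_len axis of the scores
alibi_bias = distance.view(1, 1, -1) * alibi_slopes.view(n_heads, 1, 1)

scores = torch.zeros(n_heads, q_len, kv_len)
scores += alibi_bias                                   # -> shape [2, 3, 5]
print(alibi_bias[0, 0])  # tensor([0.0000, 0.5000, 1.0000, 1.5000, 2.0000])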
