From 000f0612274dcd2ae8c4a45dd2cb709aa09daa60 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 25 Sep 2024 09:04:40 +0800 Subject: [PATCH 01/12] Refactor prepare_model_input --- vllm/attention/backends/ipex_attn.py | 21 +-- vllm/worker/xpu_model_runner.py | 216 ++++++++++++++++++++++++--- 2 files changed, 197 insertions(+), 40 deletions(-) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 23b3e6ec8c0cc..cef087396d565 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -67,6 +67,8 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] seqlen_q: Optional[torch.Tensor] max_seqlen: Optional[int] + query_start_loc: Optional[torch.Tensor] + context_lens: Optional[torch.Tensor] def __post_init__(self): # Set during the execution of the first attention op. @@ -265,25 +267,6 @@ def forward( att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks - # output = torch.empty( - # (num_tokens, self.num_heads, self.head_size), - # dtype=query.dtype, - # device=query.device) - # ipex_ops.varlen_attention(query, - # key, - # value, - # output, - # attn_metadata.seqlen_q, - # attn_metadata.seqlen_q, - # attn_metadata.max_seqlen, - # attn_metadata.max_seqlen, - # pdropout=0.0, - # softmax_scale=self.scale, - # zero_tensors=False, - # is_causal=True, - # return_softmax=False, - # gen_=None) - output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype, device=query.device) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 025449cfe4853..d70c2be213e65 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, Mapping import torch @@ -26,6 +27,7 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from vllm.attention.backends.utils import is_block_tables_empty if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -237,35 +239,207 @@ def prepare_model_input( virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPU: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. 
-        if is_prompt:
-            (input_tokens, input_positions, attn_metadata, seq_lens,
-             multi_modal_kwargs
-             ) = self._prepare_prompt(seq_group_metadata_list)
-        else:
-            (input_tokens, input_positions,
-             attn_metadata) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+        input_tokens: List[int] = []
+        input_positions: List[int] = []
+        slot_mapping: List[int] = []
+
+        seq_lens: List[int] = []
+        # For a prefill, seq_len = already-computed (chunked) context length
+        # plus this step's query length.
+        prefill_seq_lens: List[int] = []
+        decode_seq_lens: List[int] = []
+        context_lens: List[int] = []
+        query_lens: List[int] = []
+        # One entry per sequence, listing its physical blocks.
+        block_tables: List[List[int]] = []
+
+        num_prefills = 0
+        num_prefill_tokens = 0
+        num_decode_tokens = 0
+
+        if len(seq_group_metadata_list) == 0:
+            return None
+
+        assert self.sliding_window is None, "TODO: support sliding window later"
+
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            # is_prompt indicates the sequence group is still in the
+            # prompt (prefill) stage.
+            # TODO: remove this is_prompt
+            is_prompt = seq_group_metadata.is_prompt
+
+            # Iterate over all the seqs in the seq_group
+            for seq_id in seq_ids:
+                # Check for prefix caching
+                computed_block_nums = seq_group_metadata.computed_block_nums
+                if (self.scheduler_config is not None
+                        and self.scheduler_config.chunked_prefill_enabled
+                        and not (computed_block_nums is None
+                                 or computed_block_nums == [])):
+                    raise RuntimeError(
+                        "Chunked prefill cannot be used with prefix caching")
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                # context_len: the number of tokens that have already been
+                # computed for this sequence.
+                if is_prompt:
+                    context_len = seq_data.get_num_computed_tokens()
+                else:
+                    context_len = seq_data.get_len() - 1
+
+                # Determine seq_len for this step. For prefill it is capped
+                # by token_chunk_size (the min's second argument); for
+                # decoding it is the full sequence length (the first one).
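+                # Worked example (illustrative numbers): with
+                # get_len() = 100, context_len = 32 and
+                # token_chunk_size = 16, a chunked-prefill step gets
+                # seq_len = min(100, 48) = 48, i.e. a 16-token query on top
+                # of 32 cached tokens; a decode step has context_len = 99
+                # and gets seq_len = 100.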
+                seq_len = min(
+                    seq_data.get_len(),
+                    context_len + seq_group_metadata.token_chunk_size)
+
+                if is_prompt:
+                    tokens = seq_data.get_token_ids()[context_len:seq_len]
+                else:
+                    # Last token
+                    tokens = [seq_data.get_last_token_id()]
+
+                # FIXME: add prefix caching
+                if (self.scheduler_config.chunked_prefill_enabled
+                        or not is_prompt):
+                    # Chunked prefill or decoding.
+                    # For chunked prefill, the block tables may not be None.
+                    if seq_group_metadata.block_tables is not None:
+                        block_table = seq_group_metadata.block_tables[seq_id]
+                    else:
+                        block_table = []
+                else:
+                    # Prefill without chunked prefill.
+                    block_table = []
+                block_tables.append(block_table)
+                # Total seq_lens
+                seq_lens.append(seq_len)
+                context_lens.append(context_len)
+                query_len = seq_len - context_len
+                query_lens.append(query_len)
+                input_tokens.extend(tokens)
+                input_positions.extend(list(range(context_len, seq_len)))
+                if is_prompt:
+                    assert len(seq_ids) == 1
+                    num_prefills += 1
+                    num_prefill_tokens += len(tokens)
+                    prefill_seq_lens.append(seq_len)
+                else:
+                    assert query_len == 1, "Wrong query length in decoding"
+                    num_decode_tokens += 1
+                    decode_seq_lens.append(seq_len)
+                if is_block_tables_empty(seq_group_metadata.block_tables):
+                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+                    continue
+                # block_tables maps seq_id -> List[int] of physical blocks.
+                block_table = seq_group_metadata.block_tables[seq_id]
+
+                # TODO: add sliding window
+                for i in range(context_len, seq_len):
+                    # if i < start_idx:
+                    #     slot_mapping.append(_PAD_SLOT_ID)
+                    #     continue
+                    block_number = block_table[i // self.block_size]
+                    block_offset = i % self.block_size
+                    # block_table maps the logical block (i // block_size)
+                    # to a physical block; slot_mapping flattens that into
+                    # the index of token i's slot within the whole KV cache.
+                    slot_mapping.append(block_number * self.block_size
+                                        + block_offset)
+        max_query_len = max(query_lens)
+        max_decode_seq_len = max(decode_seq_lens, default=0)
+
+        max_block_table_len = max(
+            len(block_table) for block_table in block_tables)
+        block_tables = make_tensor_with_pad(
+            block_tables,
+            max_len=max_block_table_len,
+            pad=0,
+            dtype=torch.int,
+            device=self.device,
+        )
+        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=self.device)
+
+        # seq_start_loc: (batch_size + 1,). The cumulative sequence lengths
+        # of the sequences in the batch, used to index into the flattened
+        # token sequence. E.g., if the sequence lengths are [4, 6], it is
+        # [0, 4, 10].
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=self.device)
+
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        input_tokens_tensor = torch.tensor(input_tokens,
+                                           dtype=torch.long,
+                                           device=self.device)
+        input_positions_tensor = torch.tensor(input_positions,
+                                              dtype=torch.long,
+                                              device=self.device)
+        slot_mapping_tensor = torch.tensor(slot_mapping,
+                                           dtype=torch.long,
+                                           device=self.device)
+
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device=self.device)
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=self.device)
+
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=self.device)
+
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        # seqlen_q: cumulative seq_lens with a leading zero, kept for the
+        # varlen prefill path.
+        tmp = [0]
+        tmp.extend(seq_lens)
+        seqlen = torch.tensor(tmp)
+        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
+
+        # Generate attn_metadata
+        is_prompt = (seq_group_metadata_list[0].is_prompt
+                     if seq_group_metadata_list else None)
+        attn_metadata = self.attn_backend.make_metadata(
+            # FIXME: Later maybe we can get rid of this parameter
+            is_prompt=is_prompt,
+            num_prefills=num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens_tensor,
+            seqlen_q=seqlen_q,
+            max_seqlen=max_query_len,
+            seq_lens_tensor=seq_lens_tensor,
+            # max_query_len=max_query_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            # seq_start_loc=seq_start_loc,
+            context_lens=context_lens_tensor,
+            block_tables=block_tables
+            if (self.scheduler_config.chunked_prefill_enabled
+                or not is_prompt)
+            else torch.tensor([], device=self.device, dtype=torch.int),
+        )

         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
-            # subquery_lens is not needed if chunked prefill is not
-            # supported. Since CPU worker doesn't support chunked prefill
-            # just use seq_lens instead.
+            # Pass the per-step query lengths: with chunked prefill,
+            # sampling must index only the newly computed tokens.
- seq_lens, + query_lens, self.device, pin_memory=False) - - return ModelInputForXPU(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - virtual_engine=virtual_engine) + return ModelInputForXPU( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + multi_modal_kwargs=None, + virtual_engine=virtual_engine + ) def _prepare_decode( self, From b6b1418aae7f579294a171a6a0c3fdd9536fca2d Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 25 Sep 2024 10:03:03 +0800 Subject: [PATCH 02/12] Enable ipex-attn.py backend with new model_input --- vllm/attention/backends/ipex_attn.py | 141 +++++++++++++++++++-------- 1 file changed, 103 insertions(+), 38 deletions(-) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index cef087396d565..433fd44dd28d5 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -70,6 +70,10 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): query_start_loc: Optional[torch.Tensor] context_lens: Optional[torch.Tensor] + + _cached_prefill_metadata: Optional["IpexAttnMetadata"] = None + _cached_decode_metadata: Optional["IpexAttnMetadata"] = None + def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -80,21 +84,65 @@ def __post_init__(self): @property def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: - # Currently chunked prefill is not supported - if self.num_decode_tokens == 0: - assert self.num_prefills > 0 - return self + if self.num_prefills == 0: + return None - return None + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = IpexAttnMetadata( + is_prompt=self.is_prompt, + seqlen_q=self.seqlen_q, + max_seqlen=self.max_seqlen, + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + # max_query_len=self.max_query_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + # seq_start_loc=None, + context_lens=self.context_lens[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + ) + return self._cached_prefill_metadata @property def decode_metadata(self) -> Optional["IpexAttnMetadata"]: - # Currently chunked prefill is not supported - if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + if self.num_decode_tokens == 0: return None - return self + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = IpexAttnMetadata( + is_prompt=self.is_prompt, + seqlen_q=self.seqlen_q, + max_seqlen=self.max_seqlen, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + 
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + # max_query_len=None, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + # seq_start_loc=None, + context_lens=None, + block_tables=self.block_tables[self.num_prefills:], + ) + return self._cached_decode_metadata from torch.nn.functional import scaled_dot_product_attention @@ -246,37 +294,52 @@ def forward( v_scale, ) - if attn_metadata.is_prompt: - assert attn_metadata.seq_lens is not None - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + # New added code-segment + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert query.shape[0] == num_prefill_tokens + num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.seq_lens is not None + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore + prefill_meta.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, + prefill_meta.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.seq_lens) - attn_metadata.attn_bias = att_masks - - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype, device=query.device) + att_masks = [None] * len(prefill_meta.seq_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - for seq_len, mask in zip(attn_metadata.seq_lens, - attn_metadata.attn_bias): + for seq_len, mask in zip(prefill_meta.seq_lens, + prefill_meta.attn_bias): end = start + seq_len if use_sdp_causal(self.head_size, query): import xe_addons @@ -301,16 +364,17 @@ def forward( output[start:end, :, :] = sub_out start = end else: - # prefix-enabled attention - raise RuntimeError( - "IPEX backend doesn't support prefix decoding.") + # TODO: add chunked prefill feature here... + pass + - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
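+            # decode_query holds the last token of each running sequence;
+            # its result is copied into output[num_prefill_tokens:] below.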
- max_seq_len = attn_metadata.max_decode_seq_len - output = torch.empty_like(query) + max_seq_len = decode_meta.max_decode_seq_len + out = torch.empty_like(decode_query) block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape + # print(f"In decoding, the shape is:{decode_query.shape}") + num_seqs, num_heads, head_size = decode_query.shape max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use @@ -326,14 +390,14 @@ def forward( if use_v1: # Run PagedAttention V1. ipex_ops.paged_attention_v1( - output, - query, + out, + decode_query, key_cache, value_cache, self.num_kv_heads, self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -356,7 +420,7 @@ def forward( ) max_logits = torch.empty_like(exp_sums) ipex_ops.paged_attention_v2( - output, + out, exp_sums, max_logits, tmp_output, @@ -365,8 +429,8 @@ def forward( value_cache, self.num_kv_heads, self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -374,8 +438,9 @@ def forward( k_scale, v_scale, ) + output[num_prefill_tokens:] = out - # Reshape the output tensor. + # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) From 5963dd406df50d783e04fc7a2dd0f51f572cb92f Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 25 Sep 2024 11:00:27 +0800 Subject: [PATCH 03/12] Kernel added --- csrc/xpu/attention_xpu.cpp | 1074 ++++++++++++++++++++++++++++++++++++ csrc/xpu/pybind.cpp | 6 + csrc/xpu/xpu_ops.h | 16 + 3 files changed, 1096 insertions(+) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 833f46eaaf726..e0aaf38645ebe 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -4,6 +4,7 @@ #endif #include #include +#include // clang-format on #include @@ -19,6 +20,7 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) +using namespace sycl::ext::intel::esimd; template struct Float_Trait { @@ -139,6 +141,924 @@ inline float block_sum( item_ct1.get_sub_group(), sum, 0); } +// How about implement a first edition that can be used with non-chunked +// prefill requests, so that we can make sure the reference for heads is +// correct +template +void context_attention_kernel_v1( + void* query, void* key, void* value, const void* block_tables, + const float scale, const void* query_start_loc, const void* seq_lens, + const void* context_lens, const int block_size, + const int x, // x in kv_cache + void* out, // output + const int block_table_stride_batch, const int block_table_stride_seq, + const int query_stride_bs, const int query_stride_head, + const int query_stride_dim, const int k_cache_stride_tokens, + const int k_cache_stride_head, const int k_cache_stride_dim, + const int k_cache_stride_block_size, const int k_cache_stride_x, + const int v_cache_stride_tokens, const int v_cache_stride_head, + const int v_cache_stride_dim, const int v_cache_stride_block_size, + const int out_stride_tokens, const int out_stride_head, + const int num_queries_per_kv, const int max_input_length, + const int batch_size, const int num_heads) { + static_assert(GS * HD * sizeof(scalar_t) * 2 < 64 * 1024); + + const size_t key_slm_offset = 0; + const size_t value_slm_offset = GS * HD * sizeof(scalar_t); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + + // Get the maximum seq_lens + sycl::range<3> global_size(batch_size, num_heads, + (max_input_length + GS - 1) / GS * GS); + sycl::range<3> local_size(1, 1, GS); + + auto cgf = [&](sycl::handler& handle) { + handle.parallel_for( + sycl::nd_range<3>(global_size, local_size), + [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { + slm_init(); + + const size_t bsz_idx = item.get_global_id(0); + const size_t head_idx = item.get_global_id(1); + // Assuming we have 32 query head and 8 kv_heads. Then + // num_queries_per_group should be 4 For head_idx 13, then + // kv_head_idx = 13 / 4 = 3, which is correct + const size_t kv_head_idx = head_idx / num_queries_per_kv; + const int32_t seq_idx = item.get_global_id(2); + const size_t gid = item.get_group(2); + const size_t tid = item.get_local_id(2); + + // const int64_t * seq_len = (const int64_t *) seq_lens; + const int32_t* seq_len = (const int32_t*)seq_lens; + int32_t seq_bound = seq_len[bsz_idx]; + + const int32_t* query_loc = (const int32_t*)query_start_loc; + // There is a possibility that the current token index pass + // over the seq_len, therefore: token_idx is the position in + // the query + int32_t token_idx = + query_loc[bsz_idx] + std::min(seq_idx, seq_bound - 1); + + const int32_t* context_len_pointer = (const int32_t*)context_lens; + + const int* block_tables_ptr = (const int*)block_tables; + const int* block_table = + block_tables_ptr + bsz_idx * block_table_stride_batch; + // I guess this context_len should be 0... 
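+              // Note: context_lens[bsz_idx] is 0 for a plain (non-chunked)
+              // prefill; with chunked prefill it counts the tokens already
+              // written to the KV cache by earlier chunks.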
+              const int32_t context_len = context_len_pointer[bsz_idx];
+
+              // Position in the sequence: context + seq_idx.
+              const int32_t token_position = context_len + seq_idx;
+
+              const scalar_t* query_head = (const scalar_t*)query +
+                                           token_idx * query_stride_bs +
+                                           head_idx * query_stride_head;
+              // Target output
+              scalar_t* out_head =
+                  (scalar_t*)out +
+                  (query_loc[bsz_idx] + seq_idx) * out_stride_tokens +
+                  head_idx * out_stride_head;
+              // The indexing for key_head will be weird...
+              // Assuming the context length is n * GS + offset_part; here
+              // we handle the n * GS part.
+              int32_t context_groups = context_len / GS;
+              // TODO: consider context groups later
+              // TODO: consider n*GS part later
+
+              // Each work-item loads its query row.
+              simd query_row =
+                  block_load(query_head) * scale;
+              simd accv = 0;
+              simd softmaxv = 0;
+              scalar_t max_attn = -sycl::detail::max_v();
+
+              // ########### Handle the n * GS context part ###########
+              int32_t n = context_len / GS;
+              int32_t context_offset = context_len % GS;
+
+              // TODO: this target_key_position has problems
+              for (int32_t group = 0; group < n; ++group) {
+                size_t target_key_position = group * GS + tid;
+                int which_block = target_key_position / block_size;
+                int which_slot = target_key_position % block_size;
+
+                int physical_block_number = block_table[which_block];
+                const scalar_t* key_head =
+                    (const scalar_t*)key +
+                    physical_block_number * k_cache_stride_tokens +
+                    kv_head_idx * k_cache_stride_head +
+                    which_slot * k_cache_stride_block_size;
+                for (int i = 0; i < HD / x; i++) {
+                  // Load 8 elements
+                  simd key_row =
+                      block_load(key_head + i * k_cache_stride_dim);
+                  slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) +
+                                      8 * i * sizeof(scalar_t),
+                                  key_row);
+                }
+
+                const scalar_t* value_head =
+                    (const scalar_t*)value +
+                    physical_block_number * v_cache_stride_tokens +
+                    kv_head_idx * v_cache_stride_head + which_slot;
+                for (int i = 0; i < HD; i++) {
+                  // Seems to have an error here
+                  scalar_t temp_value = value_head[i * v_cache_stride_dim];
+                  slm_scalar_store(value_slm_offset +
+                                       tid * HD * sizeof(scalar_t) +
+                                       i * sizeof(scalar_t),
+                                   temp_value);
+                }
+                barrier();
+
+                // Now begin to calculate attention:
+                // Calculate QK^T for this group...
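+                // Online softmax over the groups: track a running max;
+                // whenever a new group raises it, rescale the accumulated
+                // numerator (accv) and denominator (softmaxv) by
+                // exp(old_max - new_max) so earlier terms stay correctly
+                // weighted.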
+ simd attnv; +#pragma unroll + for (size_t r = 0; r < GS; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + attnv[r] = attn; + } + scalar_t new_max_attn = + std::max(hmax(attnv), max_attn); + scalar_t attn_exp = exp(max_attn - new_max_attn); + accv = accv * attn_exp; + softmaxv = softmaxv * attn_exp; + max_attn = new_max_attn; + const simd attn_expv = exp(attnv - max_attn); +#pragma unorll + for (size_t r = 0; r < GS; ++r) { + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + accv += value_row * attn_expv[r]; + } + softmaxv += attn_expv; + barrier(); + } + + // ########################### End for handling context n * + // GS part ########### + + // ############################# Handle n * GS + // ############################ + for (size_t group = 0; group < gid; ++group) { + // 1. begins to load each position's key and value + size_t target_key_position = group * GS + tid; + int which_block = target_key_position / block_size; + int which_slot = target_key_position % block_size; + + int physical_block_number = block_table[which_block]; + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + for (int i = 0; i < HD / x; i++) { + // Load 8 elements + simd key_row = + block_load(key_head + i * k_cache_stride_dim); + slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + + 8 * i * sizeof(scalar_t), + key_row); + } + + const scalar_t* value_head = + (const scalar_t*)value + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head + which_slot; + for (int i = 0; i < HD; i++) { + // Seems to have an error here + scalar_t temp_value = value_head[i * v_cache_stride_dim]; + slm_scalar_store(value_slm_offset + + tid * HD * sizeof(scalar_t) + + i * sizeof(scalar_t), + temp_value); + } + barrier(); + simd attnv; +#pragma unroll + for (size_t r = 0; r < GS; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + attnv[r] = attn; + } + + scalar_t new_max_attn = + std::max(hmax(attnv), max_attn); + scalar_t attn_exp = exp(max_attn - new_max_attn); + accv = accv * attn_exp; + + softmaxv = softmaxv * attn_exp; + max_attn = new_max_attn; + const simd attn_expv = exp(attnv - max_attn); +#pragma unorll + for (size_t r = 0; r < GS; ++r) { + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + accv += value_row * attn_expv[r]; + } + softmaxv += attn_expv; + barrier(); + } + + // ############## End of handle n * GS part + // ################## + + // ################ Handle offset part #################### + scalar_t softmax = + sycl::ext::intel::esimd::detail::sum( + softmaxv); + + // ############## handle context offset ############ + if (tid < context_offset) { + size_t target_key_position = n * GS + tid; + int which_block = target_key_position / block_size; + int which_slot = target_key_position % block_size; + + int physical_block_number = block_table[which_block]; + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + for (int i = 0; i < HD / x; i++) { + // Load 8 elements + simd key_row = + block_load(key_head + i * 
k_cache_stride_dim);
+                slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) +
+                                    8 * i * sizeof(scalar_t),
+                                key_row);
+              }
+
+              const scalar_t* value_head =
+                  (const scalar_t*)value +
+                  physical_block_number * v_cache_stride_tokens +
+                  kv_head_idx * v_cache_stride_head + which_slot;
+              for (int i = 0; i < HD; i++) {
+                // Seems to have an error here
+                scalar_t temp_value = value_head[i * v_cache_stride_dim];
+                slm_scalar_store(value_slm_offset +
+                                     tid * HD * sizeof(scalar_t) +
+                                     i * sizeof(scalar_t),
+                                 temp_value);
+              }
+            }
+
+            barrier();
+
+            // FIXME: we compute QKs for all the tokens here; only the
+            // positions that fall inside the sequence are valid.
+            if (token_position < seq_bound) {
+              // This could be an error place
+              for (size_t r = 0; r < context_offset; ++r) {
+                simd key_row = slm_block_load(
+                    key_slm_offset + r * HD * sizeof(scalar_t));
+                simd value_row = slm_block_load(
+                    value_slm_offset + r * HD * sizeof(scalar_t));
+                scalar_t attn =
+                    sycl::ext::intel::esimd::detail::sum(
+                        query_row * key_row);
+                if (attn <= max_attn) {
+                  scalar_t attn_exp =
+                      sycl::ext::intel::esimd::exp(attn - max_attn);
+                  accv += value_row * attn_exp;
+                  softmax += attn_exp;
+                } else {
+                  scalar_t attn_exp =
+                      sycl::ext::intel::esimd::exp(max_attn - attn);
+                  accv = accv * attn_exp + value_row;
+                  softmax = softmax * attn_exp + 1;
+                  max_attn = attn;
+                }
+              }
+            }
+            barrier();
+            // ############## handle seq offset #################
+            if (token_position < seq_bound) {
+              const int64_t which_block =
+                  static_cast(token_position / block_size);
+              const int64_t which_slot =
+                  static_cast(token_position % block_size);
+
+              // TODO: we might need to cast this to int64_t to avoid
+              // overflow...
+ const int64_t physical_block_number = + static_cast(block_table[which_block]); + + const scalar_t* key_head = + (const scalar_t*)key + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + which_slot * k_cache_stride_block_size; + + // Let's do a loop to load the data + // 0 to 7 + for (int i = 0; i < HD / x; i++) { + // Load 8 elements + simd key_row = + block_load(key_head + i * k_cache_stride_dim); + slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + + 8 * i * sizeof(scalar_t), + key_row); + } + + // v_cache in shape [num_blocks, num_kv_heads, + // head_size, block_size] + const scalar_t* value_head = + (const scalar_t*)value + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head + which_slot; + for (int i = 0; i < HD; i++) { + // Seems to have an error here + scalar_t temp_value = value_head[i * v_cache_stride_dim]; + slm_scalar_store(value_slm_offset + + tid * HD * sizeof(scalar_t) + + i * sizeof(scalar_t), + temp_value); + } + } + barrier(); + + if (token_position < seq_bound) { + // handle last a few of tokens + for (size_t r = 0; r <= tid; ++r) { + simd key_row = slm_block_load( + key_slm_offset + r * HD * sizeof(scalar_t)); + simd value_row = slm_block_load( + value_slm_offset + r * HD * sizeof(scalar_t)); + scalar_t attn = + sycl::ext::intel::esimd::detail::sum( + query_row * key_row); + if (attn <= max_attn) { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(attn - max_attn); + accv += value_row * attn_exp; + softmax += attn_exp; + } else { + scalar_t attn_exp = + sycl::ext::intel::esimd::exp(max_attn - attn); + accv = accv * attn_exp + value_row; + softmax = softmax * attn_exp + 1; + max_attn = attn; + } + } + + if (softmax > 0) { + simd result = accv / softmax; + block_store(out_head, result); + } else { + simd result = 0; + block_store(out_head, result); + } + } + // ######## Ending of handling seq offset ########## + }); + }; + queue.submit(cgf); +} + +// How about implement a first edition that can be used with non-chunked prefill +// requests, so that we can make sure the reference for heads is correct +template +void context_attention_kernel_v2( + void* query, void* key, void* value, const void* block_tables, + const float scale, const void* query_start_loc, const void* seq_lens, + const void* context_lens, const int block_size, + const int x, // x in kv_cache + void* out, // output + const int block_table_stride_batch, const int block_table_stride_seq, + const int query_stride_bs, const int query_stride_head, + const int query_stride_dim, const int k_cache_stride_tokens, + const int k_cache_stride_head, const int k_cache_stride_dim, + const int k_cache_stride_block_size, const int k_cache_stride_x, + const int v_cache_stride_tokens, const int v_cache_stride_head, + const int v_cache_stride_dim, const int v_cache_stride_block_size, + const int out_stride_tokens, const int out_stride_head, + const int num_queries_per_kv, const int max_input_length, + const int batch_size, const int num_heads, const int num_tokens, + const int max_context_len) { + constexpr int BLOCK_SIZE = 16; + constexpr int NUM_THREADS = 128; + // Each wrap handles one context block, therefore, each thread_group_size is + // this. 
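+  // For example, with WARP_SIZE = 32 and BLOCK_SIZE = 16 this gives
+  // THREAD_GROUP_SIZE = 2; in half precision VEC_SIZE = 16 / (2 * 2) = 4,
+  // so num_vecs_per_thread = 128 / 2 / 4 = 16 when HD = 128.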
+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + // Each query, and key thread_group loads 16 bytes + // Assume TGS=4 then 16 / 4 / sizeof(half) = 2 + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using sycl_t = vllm::xpu::SyclTypeTrait::Type; + using Q_Vec = typename Vec::Type; + + // Assuming HD = 128, TGS = 2, then 128 / 2 / 2 = 32 + int num_vecs_per_thread = HD / THREAD_GROUP_SIZE / VEC_SIZE; + sycl_t* out_p = reinterpret_cast(out); + sycl_t* query_ptr = reinterpret_cast(query); + sycl_t* key_cache_ptr = reinterpret_cast(key); + sycl_t* value_cache_ptr = reinterpret_cast(value); + const int* query_loc_ptr = reinterpret_cast(query_start_loc); + const int* block_tables_ptr = reinterpret_cast(block_tables); + const int* context_lens_ptr = reinterpret_cast(context_lens); + const int* seq_lens_ptr = reinterpret_cast(seq_lens); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = + DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * HD * sizeof(float); + // Python-side check in + // vllm.worker.worker._check_if_can_support_max_seq_len Keep that in + // sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + // WARN: we have changed this... + sycl::range<3> grid(batch_size, num_heads, max_input_length); + // One work-group that is executing on the device + sycl::range<3> block(1, 1, NUM_THREADS); + sycl::queue& queue = vllm::xpu::vllmGetQueue(); + + auto cgf = [&](sycl::handler& handle) { + // sycl::stream output_stream(128000, 128, handle); + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(shared_mem_size), handle); + sycl::local_accessor q_vecs_acc_ct1( + sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), handle); + sycl::local_accessor red_smem_acc_ct1( + sycl::range<1>(2 * NUM_WARPS), handle); + + handle.parallel_for( + // (batch_size, num_heads, max_input_length * 128) (1, 1, 128) + // Each workgroup handles one token + sycl::nd_range<3>(grid * block, block), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + // FIXME: change this... + // const int bsz_idx = item_ct1.get_global_id(0); + const int bsz_idx = item_ct1.get_group(0); + const int seq_idx = item_ct1.get_group(2); + constexpr bool USE_PARTITIONING = false; + const int context_len = context_lens_ptr[bsz_idx] + seq_idx; + const int seq_len = seq_lens_ptr[bsz_idx]; + uint8_t* dpct_local = dpct_local_acc_ct1.get_pointer(); + Q_Vec* q_vecs = q_vecs_acc_ct1.get_pointer(); + float* red_smem = red_smem_acc_ct1.get_pointer(); + + // output_stream << "Original context_len: " << + // context_lens_ptr[bsz_idx] << sycl::endl; output_stream << + // "Batch_idx: " << bsz_idx << " Seq_idx: " << seq_idx + // << " Context_len: " << context_len << " Original context_len: " + // << context_lens_ptr[bsz_idx] << " Seq_len: " << seq_len + // << " Max input length: " << max_input_length + // << sycl::endl; + // FIXME: chang this to >= + // Assuming seq_len is 5, then seq_idx should be 0, 1, 2, 3, 4, 5 + // Shall the query token attend to itself? 
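+            // The kernel launches max_input_length work-groups for every
+            // sequence, so seq_idx can run past this sequence's end; the
+            // check below culls those padded positions.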
+ if (context_len >= seq_len) { + return; + } + + const int num_context_blocks = + DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = num_context_blocks; + + const int start_block_idx = 0; + // TODO: remove this + const int end_block_idx = + MIN(start_block_idx + num_context_blocks, num_context_blocks); + + const int num_blocks = end_block_idx - start_block_idx; + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + // THREAD_GROUP_SIZE equals to 2 + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + // 128 / 2 = 64 THREAD GROUPS -> 4 warps, 16 thread group per + // warp + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / + THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = item_ct1.get_local_id(2); + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + const int head_idx = item_ct1.get_group(1); + const int num_heads = item_ct1.get_group_range(1); + const int kv_head_idx = head_idx / num_queries_per_kv; + // TODO: consider alibi_slope later + constexpr int NUM_ELEMS_PER_THREAD = HD / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + // num_tokens, num_heads, HD + // TODO: fix this + // const sycl_t* q_ptr = + // query_ptr + seq_idx * query_stride_bs + head_idx * HD; + const sycl_t* q_ptr = + query_ptr + (query_loc_ptr[bsz_idx] + seq_idx) * query_stride_bs + + head_idx * HD; + +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset * NUM_VECS_PER_THREAD + i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + // Loaded q_vecs + item_ct1.barrier(sycl::access::fence_space::local_space); + auto shared_mem = (char*)dpct_local; + float* logits = reinterpret_cast(shared_mem); + constexpr int x = 16 / sizeof(sycl_t); + float qk_max = -FLT_MAX; + // TODO: check if block_table include everything? + const int* block_table = + block_tables_ptr + bsz_idx * block_table_stride_batch; + + // Loading key + for (int block_idx = start_block_idx + warp_idx; + block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = + block_idx * BLOCK_SIZE + physical_block_offset; + + Q_Vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const sycl_t* k_ptr = + key_cache_ptr + + physical_block_number * k_cache_stride_tokens + + kv_head_idx * k_cache_stride_head + + physical_block_offset * x; + + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. 
+ // This includes a reduction across the threads in the + // same thread group. Q_Vec_t + // q_vec_[NUM_VECS_PER_THREAD] = q_vecs + + // thread_group_offset * THREAD_GROUP_SIZE; + float qk = scale * + Qk_dot::template dot< + Q_Vec, NUM_VECS_PER_THREAD>( + q_vecs + thread_group_offset * NUM_VECS_PER_THREAD, + k_vecs, item_ct1); + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the + // masked logits. + const bool mask = token_idx >= context_len; + // TODO: uncomment + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); + } + } + } +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + /* + DPCT1096:38: The right-most dimension of the work-group used + in the SYCL kernel that calls this function may be less than + "32". The function "dpct::permute_sub_group_by_xor" may + return an unexpected result on the CPU device. Modify the + size of the work-group to ensure that the value of the + right-most dimension is a multiple of "32". + */ + qk_max = + sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:39: The right-most dimension of the work-group used + in the SYCL kernel that calls this function may be less than + "32". The function "dpct::permute_sub_group_by_xor" may + return an unexpected result on the CPU device. Modify the + size of the work-group to ensure that the value of the + right-most dimension is a multiple of "32". + */ + qk_max = + sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), qk_max, mask)); + } + qk_max = + dpct::select_from_sub_group(item_ct1.get_sub_group(), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = sycl::exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = + block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); + // Compute softmax. + const float inv_sum = 1.f / (exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + constexpr int V_VEC_SIZE = MIN(16 / sizeof(sycl_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HD, NUM_ROWS_PER_ITER); + // NOTE(woosuk): We use FP32 for the accumulator for better + // accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + sycl_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; + block_idx < end_block_idx; block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. 
+ // However, we cast it to int64 because int32 can lead to + // overflow when this variable is multiplied by large + // numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = + (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = + block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + vllm::from_float( + logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const sycl_t* v_ptr = + value_cache_ptr + + physical_block_number * v_cache_stride_tokens + + kv_head_idx * v_cache_stride_head; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_context_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens + // that are out of the context, we should + // explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + sycl_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += vllm::dot(logits_vec, v_vec); + } + } + } + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + /* + DPCT1096:41: The right-most dimension of the work-group + used in the SYCL kernel that calls this function may be + less than "32". The function + "dpct::permute_sub_group_by_xor" may return an + unexpected result on the CPU device. Modify the size of + the work-group to ensure that the value of the + right-most dimension is a multiple of "32". + */ + acc += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), + acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory + // space for logits is reused for the output. + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + // Write the final output. 
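+          // After the tree reduction above, warp 0 holds the fully reduced
+          // accumulators, so only it writes this head's output row.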
+ if (warp_idx == 0) { + sycl_t* out_ptr = + out_p + (query_loc_ptr[bsz_idx] + seq_idx) * out_stride_tokens + + head_idx * out_stride_head; + + // sycl_t* out_ptr = + // out_p + + // seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + // head_idx * max_num_partitions * HEAD_SIZE + + // partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = + lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { + vllm::from_float(*(out_ptr + row_idx), accs[i]); + } + } + } + }); + // Each thread_group handles one token + }; + queue.submit(cgf); +} + template < typename scalar_t, typename Q_Vec_t, @@ -1251,4 +2171,158 @@ void paged_attention_v2( query.scalar_type(), "paged_attention_xpu_v2_impl", [&] { CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t); }); +} + +torch::Tensor context_attention_forward_v2( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length) { + // TODO: Dispatch to different query.scalar_type() if needed. + int64_t num_tokens = query.size(0); + int64_t num_heads = query.size(1); + int64_t head_dim = query.size(2); + int64_t batch_size = seq_lens.size(0); + int num_kv_heads = value.size(1); + + int key_dimension = key.dim(); + auto output = at::empty({query.size(0), query.size(1), query.size(2)}, + at::device(query.device()).dtype(query.dtype())); + + // key should be in shape: + // 1. [num_tokens, num_kv_head, head_dim] + assert(key_dimension == 3 or key_dimension == 5); + assert(query.scalar_type() == key.scalar_type() && + query.scalar_type() == value.scalar_type()); + assert(head_dim == 128); + assert(query.scalar_type() == at::ScalarType::Half); + + int query_stride_token = query.stride(0); + int query_stride_head = query.stride(1); + int query_stride_dim = query.stride(2); + const float attn_scale = 1 / std::sqrt((float)head_dim); + + assert(num_heads % num_kv_heads == 0); + int num_queries_per_kv = num_heads / num_kv_heads; + + + // key: num_blocks, num_kv_heads, head_size // x, num_blocks, x) + // value: [num_blocks, num_kv_heads, head_size, block_dim] + int block_size = value.size(3); + int x = key.size(4); + int block_table_stride_bsz = block_tables.stride(0); + int block_table_stride_seq = block_tables.stride(1); + int k_cache_stride_token = key.stride(0); + int k_cache_stride_head = key.stride(1); + int k_cache_stride_head_dim = key.stride(2); + int k_cache_stride_block = key.stride(3); + int k_cache_stride_x = key.stride(4); + + int v_cache_stride_token = value.stride(0); + int v_cache_stride_head = value.stride(1); + int v_cache_stride_head_dim = value.stride(2); + int v_cache_stride_block = value.stride(3); + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), 
output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + + // vllm::context_attention_kernel( + // query.data_ptr(), key.data_ptr(), value.data_ptr(), + // block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + // seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + // output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + // query_stride_token, query_stride_head, query_stride_dim, + // k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + // k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + // v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + // output.stride(0), output.stride(1), num_queries_per_kv, + // max_input_length, batch_size, num_heads); + return output; +} + +torch::Tensor context_attention_forward_v1( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length) { + // TODO: Dispatch to different query.scalar_type() if needed. + int64_t num_tokens = query.size(0); + int64_t num_heads = query.size(1); + int64_t head_dim = query.size(2); + int64_t batch_size = seq_lens.size(0); + int num_kv_heads = value.size(1); + + int key_dimension = key.dim(); + auto output = at::empty({query.size(0), query.size(1), query.size(2)}, + at::device(query.device()).dtype(query.dtype())); + + // key should be in shape: + // 1. [num_tokens, num_kv_head, head_dim] + assert(key_dimension == 3 or key_dimension == 5); + assert(query.scalar_type() == key.scalar_type() && + query.scalar_type() == value.scalar_type()); + assert(head_dim == 128); + assert(query.scalar_type() == at::ScalarType::Half); + + int query_stride_token = query.stride(0); + int query_stride_head = query.stride(1); + int query_stride_dim = query.stride(2); + const float attn_scale = 1 / std::sqrt((float)head_dim); + + assert(num_heads % num_kv_heads == 0); + int num_queries_per_kv = num_heads / num_kv_heads; + + // key: num_blocks, num_kv_heads, head_size // x, num_blocks, x) + // value: [num_blocks, num_kv_heads, head_size, block_dim] + int block_size = value.size(3); + int x = key.size(4); + int block_table_stride_bsz = block_tables.stride(0); + int block_table_stride_seq = block_tables.stride(1); + int k_cache_stride_token = key.stride(0); + int k_cache_stride_head = key.stride(1); + int k_cache_stride_head_dim = key.stride(2); + int k_cache_stride_block = key.stride(3); + int k_cache_stride_x = key.stride(4); + + int v_cache_stride_token = value.stride(0); + int v_cache_stride_head = value.stride(1); + int v_cache_stride_head_dim = value.stride(2); + int v_cache_stride_block = value.stride(3); + // vllm::context_attention_kernel_v2( + // query.data_ptr(), key.data_ptr(), value.data_ptr(), + // block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + // seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + // output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + // query_stride_token, query_stride_head, query_stride_dim, + // k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + // k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + // v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + // output.stride(0), 
output.stride(1), num_queries_per_kv, max_input_length, + // batch_size, num_heads, query.size(0), max_context_length); + + vllm::context_attention_kernel_v1( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads); + return output; } \ No newline at end of file diff --git a/csrc/xpu/pybind.cpp b/csrc/xpu/pybind.cpp index 4e7f2fa6bd80f..e42ae45c6a50c 100644 --- a/csrc/xpu/pybind.cpp +++ b/csrc/xpu/pybind.cpp @@ -75,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "awq_dequantize", &awq_dequantize, "dequant method for awq"); + + ops.def("context_attention_forward_v1", &context_attention_forward_v1, + "Context attention forward_v1"); + + ops.def("context_attention_forward_v2", &context_attention_forward_v2, + "Context attention forward_v2"); } diff --git a/csrc/xpu/xpu_ops.h b/csrc/xpu/xpu_ops.h index 6125b19ac80b5..db7ceef1da343 100644 --- a/csrc/xpu/xpu_ops.h +++ b/csrc/xpu/xpu_ops.h @@ -40,6 +40,22 @@ void paged_attention_v2( int max_context_len, const c10::optional &alibi_slopes, const std::string& kv_cache_dtype, const float kv_scale); +torch::Tensor context_attention_forward_v1( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length); + +torch::Tensor context_attention_forward_v2( + torch::Tensor query, // [num_tokens, num_kv_head, head_dim] + torch::Tensor key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor value, // [num_tokens, num_kv_heads * head_size] + torch::Tensor block_tables, torch::Tensor query_start_loc, + torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, + int max_context_length); + void copy_blocks( std::vector &key_caches, std::vector &value_caches, From cd9ec729e6bd24d3cc2ce6c107d50c08177d382a Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 25 Sep 2024 14:08:56 +0800 Subject: [PATCH 04/12] Enable chunked-prefill branch --- vllm/attention/backends/ipex_attn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 433fd44dd28d5..e2f4f0fecb88c 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -364,8 +364,14 @@ def forward( output[start:end, :, :] = sub_out start = end else: - # TODO: add chunked prefill feature here... 
- pass + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + import vllm._C.ops + out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: From 9ec2a46d5971e3ec0df86afcc03c9437487f3a1b Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 25 Sep 2024 17:39:36 +0800 Subject: [PATCH 05/12] fix minor error in decoding path --- vllm/attention/backends/ipex_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index e2f4f0fecb88c..45ebb4ef62201 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -430,7 +430,7 @@ def forward( exp_sums, max_logits, tmp_output, - query, + decode_query, key_cache, value_cache, self.num_kv_heads, From ee4500b709aeffd749f3928cb594cb294b339e4a Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 25 Sep 2024 23:47:18 +0800 Subject: [PATCH 06/12] Fix generation error --- csrc/xpu/attention_xpu.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index e0aaf38645ebe..a5f1e4f9f43b6 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -723,7 +723,7 @@ void context_attention_kernel_v2( const int bsz_idx = item_ct1.get_group(0); const int seq_idx = item_ct1.get_group(2); constexpr bool USE_PARTITIONING = false; - const int context_len = context_lens_ptr[bsz_idx] + seq_idx; + int context_len = context_lens_ptr[bsz_idx] + seq_idx; const int seq_len = seq_lens_ptr[bsz_idx]; uint8_t* dpct_local = dpct_local_acc_ct1.get_pointer(); Q_Vec* q_vecs = q_vecs_acc_ct1.get_pointer(); @@ -743,6 +743,8 @@ void context_attention_kernel_v2( return; } + context_len = context_len + 1; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_blocks_per_partition = num_context_blocks; From aa65457ccdcc2077db7ce2c5df050e07271e3d43 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Thu, 26 Sep 2024 12:57:18 +0800 Subject: [PATCH 07/12] Fix long input error for v1 --- csrc/xpu/attention_xpu.cpp | 2 +- vllm/attention/backends/ipex_attn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index a5f1e4f9f43b6..0338f749f8445 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -337,7 +337,7 @@ void context_attention_kernel_v1( // ############################ for (size_t group = 0; group < gid; ++group) { // 1. 
begins to load each position's key and value - size_t target_key_position = group * GS + tid; + size_t target_key_position = context_len + group * GS + tid; int which_block = target_key_position / block_size; int which_slot = target_key_position % block_size; diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 45ebb4ef62201..8272a7e851408 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -369,7 +369,7 @@ def forward( value = value.repeat_interleave(self.num_queries_per_kv, dim=1) import vllm._C.ops - out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) + out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out From 4ffaa6652f6a540f173b5f842f6306649c77bbbe Mon Sep 17 00:00:00 2001 From: gc-fu Date: Mon, 7 Oct 2024 20:13:27 +0800 Subject: [PATCH 08/12] Fix context_attention_v2 --- csrc/xpu/attention_xpu.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 0338f749f8445..5d4a7bc94ccb5 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -388,7 +388,7 @@ void context_attention_kernel_v1( softmaxv = softmaxv * attn_exp; max_attn = new_max_attn; const simd attn_expv = exp(attnv - max_attn); -#pragma unorll +#pragma unroll for (size_t r = 0; r < GS; ++r) { simd value_row = slm_block_load( value_slm_offset + r * HD * sizeof(scalar_t)); @@ -691,7 +691,7 @@ void context_attention_kernel_v2( constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_context_len = - DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + DIVIDE_ROUND_UP(max_context_len + 1 + max_input_length, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * HD * sizeof(float); // Python-side check in @@ -848,7 +848,12 @@ void context_attention_kernel_v2( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the // masked logits. - const bool mask = token_idx >= context_len; + // TODO: consider set this to > biger position. + // Consider context_len is 512 (511 real + 1 query token) + // valid token_idx should be in the range of [0, context_len] + // And we shall set this to > + // const bool mask = token_idx >= context_len; + const bool mask = token_idx > context_len; // TODO: uncomment logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. 
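
Patches 06 through 08 are all adjusting the same bookkeeping: once prefill runs against a populated KV cache, context_len for a query token counts the tokens already cached plus the earlier tokens of the current chunk, and the token must also be able to attend to itself. A minimal sketch of the mask rule the series settles on (plain Python rather than the SYCL kernel; names are illustrative):

    # Positions 0..context_len inclusive stay visible for one query token,
    # mirroring `const bool mask = token_idx > context_len;` above. The
    # earlier `>=` comparison also masked position context_len itself,
    # silently dropping one valid key/value pair per query token.
    def masked_logits(logits, context_len):
        return [0.0 if token_idx > context_len else logit
                for token_idx, logit in enumerate(logits)]

    masked_logits([0.3, 1.2, 0.7, 0.9], context_len=2)  # -> [0.3, 1.2, 0.7, 0.0]
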
From 6dbaf876602b11cb39db60d3cff8b31be53dcef1 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Mon, 7 Oct 2024 20:32:55 +0800
Subject: [PATCH 09/12] add environment variable to control kernel selection

---
 vllm/attention/backends/ipex_attn.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index 8272a7e851408..9e3847aa81423 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -10,6 +10,7 @@
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
+import os
 
 _PARTITION_SIZE = 512
 
@@ -369,7 +370,12 @@ def forward(
                 value = value.repeat_interleave(self.num_queries_per_kv,
                                                 dim=1)
             import vllm._C.ops
-            out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item())
+            use_context_v2 = os.environ.get('USE_CONTEXT_V2')
+            if use_context_v2 is not None:
+                assert self.head_size == 128
+                out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item())
+            else:
+                out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item())
             assert output[:num_prefill_tokens].shape == out.shape
             output[:num_prefill_tokens] = out
 

From 59ec323a7ead4991cb373860ad6639343527fcf7 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Tue, 8 Oct 2024 15:32:56 +0800
Subject: [PATCH 10/12] remove unused comments

---
 csrc/xpu/attention_xpu.cpp | 182 +++----------------------------------
 1 file changed, 14 insertions(+), 168 deletions(-)

diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp
index 5d4a7bc94ccb5..e1761b539c0f5 100644
--- a/csrc/xpu/attention_xpu.cpp
+++ b/csrc/xpu/attention_xpu.cpp
@@ -238,12 +238,8 @@ void context_attention_kernel_v1(
         (scalar_t*)out + (query_loc[bsz_idx] + seq_idx) * out_stride_tokens +
         head_idx * out_stride_head;
-    // The indexing for key_head will be wired...
- // Assuming context length is in n * GS + offset_part, now - // we are handling the n * GS part + int32_t context_groups = context_len / GS; - // TODO: consider context groups later - // TODO: consider n*GS part later // Each token load its query_row simd query_row = @@ -252,21 +248,10 @@ void context_attention_kernel_v1( simd softmaxv = 0; scalar_t max_attn = -sycl::detail::max_v(); - // ############################ Handle n * GS context part - // ###################### + // ################# Handle n * GS context part ###################### int32_t n = context_len / GS; int32_t context_offset = context_len % GS; - // static const CONSTANT char FMT[] = - // "GroupID = %2d seq_len = %d seq_idx = %d token_idx = - // %d token_position = %d " "context_len = %d n = %d - // context_offset = %d\n"; - // sycl::ext::oneapi::experimental::printf( - // FMT, gid, seq_bound, seq_idx, - // token_idx, token_position, context_len, n, - // context_offset); - - // TODO: this target_key_position has problems for (int32_t group = 0; group < n; ++group) { size_t target_key_position = group * GS + tid; int which_block = target_key_position / block_size; @@ -279,7 +264,7 @@ void context_attention_kernel_v1( kv_head_idx * k_cache_stride_head + which_slot * k_cache_stride_block_size; for (int i = 0; i < HD / x; i++) { - // Load 8 elements + // Load 8 elements, decided by x simd key_row = block_load(key_head + i * k_cache_stride_dim); slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + @@ -292,7 +277,6 @@ void context_attention_kernel_v1( physical_block_number * v_cache_stride_tokens + kv_head_idx * v_cache_stride_head + which_slot; for (int i = 0; i < HD; i++) { - // Seems to have an error here scalar_t temp_value = value_head[i * v_cache_stride_dim]; slm_scalar_store(value_slm_offset + tid * HD * sizeof(scalar_t) + @@ -301,7 +285,6 @@ void context_attention_kernel_v1( } barrier(); - // # Now begins to calculate attention... // Calculate QK^T for this group... simd attnv; #pragma unroll @@ -330,11 +313,9 @@ void context_attention_kernel_v1( barrier(); } - // ########################### End for handling context n * - // GS part ########### + // ########## End for handling context n * GS part ########### - // ############################# Handle n * GS - // ############################ + // ########## Handle n * GS ################ for (size_t group = 0; group < gid; ++group) { // 1. 
begins to load each position's key and value size_t target_key_position = context_len + group * GS + tid; @@ -361,7 +342,6 @@ void context_attention_kernel_v1( physical_block_number * v_cache_stride_tokens + kv_head_idx * v_cache_stride_head + which_slot; for (int i = 0; i < HD; i++) { - // Seems to have an error here scalar_t temp_value = value_head[i * v_cache_stride_dim]; slm_scalar_store(value_slm_offset + tid * HD * sizeof(scalar_t) + @@ -398,15 +378,14 @@ void context_attention_kernel_v1( barrier(); } - // ############## End of handle n * GS part - // ################## + // ######### End of handle n * GS part ########## // ################ Handle offset part #################### scalar_t softmax = sycl::ext::intel::esimd::detail::sum( softmaxv); - // ############## handle context offset ############ + // ########### handle context offset ############ if (tid < context_offset) { size_t target_key_position = n * GS + tid; int which_block = target_key_position / block_size; @@ -443,11 +422,8 @@ void context_attention_kernel_v1( barrier(); - // FIXME: For all the tokens, we will need to calculate the - // qks For tokens that are valid... if (tid < - // context_offset) { if (token_position < seq_bound) { - // This could be an error place +#pragma unroll for (size_t r = 0; r < context_offset; ++r) { simd key_row = slm_block_load( key_slm_offset + r * HD * sizeof(scalar_t)); @@ -471,107 +447,14 @@ void context_attention_kernel_v1( } } barrier(); + // ############## handle seq offset ################# - // TODO: check if this part has problem or not... - // if (seq_idx < seq_bound) { - // const int64_t which_block = - // static_cast(token_position / - // block_size); - // const int64_t which_slot = - // static_cast(token_position % - // block_size); - - // // TODO: we might need to cast this to int64_t to - // avoid - // // overflow... 
- // const int64_t physical_block_number = - // static_cast(block_table[which_block]); - - // const scalar_t* key_head = - // (const scalar_t*)key + - // physical_block_number * k_cache_stride_tokens + - // kv_head_idx * k_cache_stride_head + - // which_slot * k_cache_stride_block_size; - - // // Let's do a loop to load the data - // // 0 to 7 - // for (int i = 0; i < HD / x; i++) { - // // Load 8 elements - // simd key_row = block_load( - // key_head + i * k_cache_stride_dim); - // slm_block_store(key_slm_offset + - // tid * HD * sizeof(scalar_t) + - // 8 * i * sizeof(scalar_t), - // key_row); - // } - - // // v_cache in shape [num_blocks, num_kv_heads, - // head_size, - // // block_size] - // const scalar_t* value_head = - // (const scalar_t*)value + - // physical_block_number * v_cache_stride_tokens + - // kv_head_idx * v_cache_stride_head + which_slot; - // for (int i = 0; i < HD; i++) { - // // Seems to have an error here - // scalar_t temp_value = - // value_head[i * v_cache_stride_dim]; - // slm_scalar_store( - // value_slm_offset + tid * HD * - // sizeof(scalar_t) + - // i * sizeof(scalar_t), - // temp_value); - // } - // } - // barrier(); - - // if (seq_idx < seq_bound) { - // // handle last a few of tokens - // for (size_t r = 0; r <= tid; ++r) { - // simd key_row = - // slm_block_load( - // key_slm_offset + r * HD * - // sizeof(scalar_t)); - // simd value_row = - // slm_block_load( - // value_slm_offset + r * HD * - // sizeof(scalar_t)); - // scalar_t attn = - // sycl::ext::intel::esimd::detail::sum< - // scalar_t, scalar_t, HD>(query_row * key_row); - // if (attn <= max_attn) { - // scalar_t attn_exp = - // sycl::ext::intel::esimd::exp(attn - - // max_attn); - // accv += value_row * attn_exp; - // softmax += attn_exp; - // } else { - // scalar_t attn_exp = - // sycl::ext::intel::esimd::exp(max_attn - - // attn); - // accv = accv * attn_exp + value_row; - // softmax = softmax * attn_exp + 1; - // max_attn = attn; - // } - // } - - // if (softmax > 0) { - // simd result = accv / softmax; - // block_store(out_head, result); - // } else { - // simd result = 0; - // block_store(out_head, result); - // } - // } if (token_position < seq_bound) { const int64_t which_block = static_cast(token_position / block_size); const int64_t which_slot = static_cast(token_position % block_size); - // TODO: we might need to cast this to int64_t to avoid - // overflow... 
const int64_t physical_block_number = static_cast(block_table[which_block]); @@ -581,8 +464,6 @@ void context_attention_kernel_v1( kv_head_idx * k_cache_stride_head + which_slot * k_cache_stride_block_size; - // Let's do a loop to load the data - // 0 to 7 for (int i = 0; i < HD / x; i++) { // Load 8 elements simd key_row = @@ -592,14 +473,12 @@ void context_attention_kernel_v1( key_row); } - // v_cache in shape [num_blocks, num_kv_heads, - // head_size, block_size] + // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* value_head = (const scalar_t*)value + physical_block_number * v_cache_stride_tokens + kv_head_idx * v_cache_stride_head + which_slot; for (int i = 0; i < HD; i++) { - // Seems to have an error here scalar_t temp_value = value_head[i * v_cache_stride_dim]; slm_scalar_store(value_slm_offset + tid * HD * sizeof(scalar_t) + @@ -610,7 +489,6 @@ void context_attention_kernel_v1( barrier(); if (token_position < seq_bound) { - // handle last a few of tokens for (size_t r = 0; r <= tid; ++r) { simd key_row = slm_block_load( key_slm_offset + r * HD * sizeof(scalar_t)); @@ -647,8 +525,6 @@ void context_attention_kernel_v1( queue.submit(cgf); } -// How about implement a first edition that can be used with non-chunked prefill -// requests, so that we can make sure the reference for heads is correct template void context_attention_kernel_v2( void* query, void* key, void* value, const void* block_tables, @@ -705,7 +581,6 @@ void context_attention_kernel_v2( sycl::queue& queue = vllm::xpu::vllmGetQueue(); auto cgf = [&](sycl::handler& handle) { - // sycl::stream output_stream(128000, 128, handle); sycl::local_accessor dpct_local_acc_ct1( sycl::range<1>(shared_mem_size), handle); sycl::local_accessor q_vecs_acc_ct1( @@ -714,12 +589,8 @@ void context_attention_kernel_v2( sycl::range<1>(2 * NUM_WARPS), handle); handle.parallel_for( - // (batch_size, num_heads, max_input_length * 128) (1, 1, 128) - // Each workgroup handles one token sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - // FIXME: change this... - // const int bsz_idx = item_ct1.get_global_id(0); const int bsz_idx = item_ct1.get_group(0); const int seq_idx = item_ct1.get_group(2); constexpr bool USE_PARTITIONING = false; @@ -736,9 +607,6 @@ void context_attention_kernel_v2( // << context_lens_ptr[bsz_idx] << " Seq_len: " << seq_len // << " Max input length: " << max_input_length // << sycl::endl; - // FIXME: chang this to >= - // Assuming seq_len is 5, then seq_idx should be 0, 1, 2, 3, 4, 5 - // Shall the query token attend to itself? 
if (context_len >= seq_len) { return; } @@ -750,7 +618,6 @@ void context_attention_kernel_v2( const int num_blocks_per_partition = num_context_blocks; const int start_block_idx = 0; - // TODO: remove this const int end_block_idx = MIN(start_block_idx + num_context_blocks, num_context_blocks); @@ -759,10 +626,7 @@ void context_attention_kernel_v2( const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); const int num_tokens = end_token_idx - start_token_idx; - // THREAD_GROUP_SIZE equals to 2 constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - // 128 / 2 = 64 THREAD GROUPS -> 4 warps, 16 thread group per - // warp constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE @@ -780,10 +644,6 @@ void context_attention_kernel_v2( constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - // num_tokens, num_heads, HD - // TODO: fix this - // const sycl_t* q_ptr = - // query_ptr + seq_idx * query_stride_bs + head_idx * HD; const sycl_t* q_ptr = query_ptr + (query_loc_ptr[bsz_idx] + seq_idx) * query_stride_bs + head_idx * HD; @@ -801,7 +661,6 @@ void context_attention_kernel_v2( float* logits = reinterpret_cast(shared_mem); constexpr int x = 16 / sizeof(sycl_t); float qk_max = -FLT_MAX; - // TODO: check if block_table include everything? const int* block_table = block_tables_ptr + bsz_idx * block_table_stride_batch; @@ -848,15 +707,8 @@ void context_attention_kernel_v2( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the // masked logits. - // TODO: consider set this to > biger position. - // Consider context_len is 512 (511 real + 1 query token) - // valid token_idx should be in the range of [0, context_len] - // And we shall set this to > - // const bool mask = token_idx >= context_len; const bool mask = token_idx > context_len; - // TODO: uncomment logits[token_idx - start_token_idx] = mask ? 0.f : qk; - // Update the max value. qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); } } @@ -910,6 +762,7 @@ void context_attention_kernel_v2( block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); // Compute softmax. const float inv_sum = 1.f / (exp_sum + 1e-6f); +#pragma unroll for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[i] *= inv_sum; } @@ -1046,11 +899,6 @@ void context_attention_kernel_v2( out_p + (query_loc_ptr[bsz_idx] + seq_idx) * out_stride_tokens + head_idx * out_stride_head; - // sycl_t* out_ptr = - // out_p + - // seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - // head_idx * max_num_partitions * HEAD_SIZE + - // partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = @@ -2187,7 +2035,7 @@ torch::Tensor context_attention_forward_v2( torch::Tensor block_tables, torch::Tensor query_start_loc, torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, int max_context_length) { - // TODO: Dispatch to different query.scalar_type() if needed. + // Currently, only support fp16 here int64_t num_tokens = query.size(0); int64_t num_heads = query.size(1); int64_t head_dim = query.size(2); @@ -2198,9 +2046,7 @@ torch::Tensor context_attention_forward_v2( auto output = at::empty({query.size(0), query.size(1), query.size(2)}, at::device(query.device()).dtype(query.dtype())); - // key should be in shape: - // 1. 
[num_tokens, num_kv_head, head_dim] - assert(key_dimension == 3 or key_dimension == 5); + assert(key_dimension == 5); assert(query.scalar_type() == key.scalar_type() && query.scalar_type() == value.scalar_type()); assert(head_dim == 128); @@ -2265,7 +2111,7 @@ torch::Tensor context_attention_forward_v1( torch::Tensor block_tables, torch::Tensor query_start_loc, torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, int max_context_length) { - // TODO: Dispatch to different query.scalar_type() if needed. + // Currently, only support fp16 int64_t num_tokens = query.size(0); int64_t num_heads = query.size(1); int64_t head_dim = query.size(2); From b641e4488ee6dd62727ec90a819148a2028713f3 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 8 Oct 2024 16:07:35 +0800 Subject: [PATCH 11/12] add head_dim 64 --- csrc/xpu/attention_xpu.cpp | 66 +++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index e1761b539c0f5..8e9b52e751026 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -2089,18 +2089,6 @@ torch::Tensor context_attention_forward_v2( output.stride(0), output.stride(1), num_queries_per_kv, max_input_length, batch_size, num_heads, query.size(0), max_context_length); - - // vllm::context_attention_kernel( - // query.data_ptr(), key.data_ptr(), value.data_ptr(), - // block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - // seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - // output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - // query_stride_token, query_stride_head, query_stride_dim, - // k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - // k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - // v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - // output.stride(0), output.stride(1), num_queries_per_kv, - // max_input_length, batch_size, num_heads); return output; } @@ -2154,28 +2142,34 @@ torch::Tensor context_attention_forward_v1( int v_cache_stride_head = value.stride(1); int v_cache_stride_head_dim = value.stride(2); int v_cache_stride_block = value.stride(3); - // vllm::context_attention_kernel_v2( - // query.data_ptr(), key.data_ptr(), value.data_ptr(), - // block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - // seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - // output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - // query_stride_token, query_stride_head, query_stride_dim, - // k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - // k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - // v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - // output.stride(0), output.stride(1), num_queries_per_kv, max_input_length, - // batch_size, num_heads, query.size(0), max_context_length); - - vllm::context_attention_kernel_v1( - query.data_ptr(), key.data_ptr(), value.data_ptr(), - block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - query_stride_token, query_stride_head, query_stride_dim, - k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - output.stride(0), output.stride(1), 
num_queries_per_kv, - max_input_length, batch_size, num_heads); + switch(head_dim) { + case 128: + vllm::context_attention_kernel_v1( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads); + break; + case 64: + vllm::context_attention_kernel_v1( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads); + break; + default: throw std::runtime_error("unsupported head_dim"); + } return output; -} \ No newline at end of file +} From 0831fd18ad2ecb44712e1747e383737b66912a6d Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 8 Oct 2024 20:41:18 +0800 Subject: [PATCH 12/12] add more head_dim for v2 --- csrc/xpu/attention_xpu.cpp | 73 +++++++++++++++++++++++++++++++------- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp index 8e9b52e751026..24134cfafffa4 100644 --- a/csrc/xpu/attention_xpu.cpp +++ b/csrc/xpu/attention_xpu.cpp @@ -2064,6 +2064,8 @@ torch::Tensor context_attention_forward_v2( // key: num_blocks, num_kv_heads, head_size // x, num_blocks, x) // value: [num_blocks, num_kv_heads, head_size, block_dim] int block_size = value.size(3); + // Currently, only block_size 16 is supported... 
+ assert(block_size == 16); int x = key.size(4); int block_table_stride_bsz = block_tables.stride(0); int block_table_stride_seq = block_tables.stride(1); @@ -2077,18 +2079,65 @@ torch::Tensor context_attention_forward_v2( int v_cache_stride_head = value.stride(1); int v_cache_stride_head_dim = value.stride(2); int v_cache_stride_block = value.stride(3); - vllm::context_attention_kernel_v2( - query.data_ptr(), key.data_ptr(), value.data_ptr(), - block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), - seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, - output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, - query_stride_token, query_stride_head, query_stride_dim, - k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, - k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, - v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, - output.stride(0), output.stride(1), num_queries_per_kv, - max_input_length, batch_size, num_heads, query.size(0), - max_context_length); + switch(head_dim) { + case 128: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + case 64: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + case 80: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + case 96: + vllm::context_attention_kernel_v2( + query.data_ptr(), key.data_ptr(), value.data_ptr(), + block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), + seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, + output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, + query_stride_token, query_stride_head, 
query_stride_dim, + k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, + k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, + v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, + output.stride(0), output.stride(1), num_queries_per_kv, + max_input_length, batch_size, num_heads, query.size(0), + max_context_length); + break; + default: throw std::runtime_error("unsupported head_dim"); + } return output; }
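
With the full series applied, the prefill-with-cache path dispatches on head_dim: context_attention_forward_v1 handles 64 and 128, context_attention_forward_v2 handles 64, 80, 96, and 128 (and additionally requires block_size 16), both paths are fp16-only, and any other head_dim throws "unsupported head_dim". Selection between the two stays with the USE_CONTEXT_V2 environment variable from patch 09, whose Python wrapper still asserts head_size == 128 on the v2 path even after patch 12 widens the kernel. A rough shape-level sketch of a call (illustrative only; the sizes, the int32 dtype of the index tensors, and the "xpu" device string are assumptions, not taken from the series):

    import os
    import torch
    import vllm._C.ops as ops  # the extension registered in csrc/xpu/pybind.cpp

    num_tokens, num_heads, num_kv_heads, head_dim = 4, 32, 8, 128
    num_blocks, block_size, batch_size = 64, 16, 1
    x = 8  # 16 bytes / sizeof(fp16): the packing factor of the key cache

    query = torch.rand(num_tokens, num_heads, head_dim,
                       dtype=torch.half, device="xpu")
    # Paged KV cache, laid out as the strides in the C++ entry points expect:
    # key:   [num_blocks, num_kv_heads, head_dim // x, block_size, x]
    # value: [num_blocks, num_kv_heads, head_dim, block_size]
    key_cache = torch.rand(num_blocks, num_kv_heads, head_dim // x,
                           block_size, x, dtype=torch.half, device="xpu")
    value_cache = torch.rand(num_blocks, num_kv_heads, head_dim, block_size,
                             dtype=torch.half, device="xpu")
    # Placeholder block table; real callers take this from the block manager.
    block_tables = torch.zeros(batch_size, num_blocks,
                               dtype=torch.int32, device="xpu")
    query_start_loc = torch.tensor([0, num_tokens],
                                   dtype=torch.int32, device="xpu")
    seq_lens = torch.tensor([20], dtype=torch.int32, device="xpu")      # cached + new
    context_lens = torch.tensor([16], dtype=torch.int32, device="xpu")  # cached only
    max_input_length = num_tokens
    max_context_length = int(torch.amax(context_lens).item())

    fn = (ops.context_attention_forward_v2
          if os.environ.get("USE_CONTEXT_V2") is not None
          else ops.context_attention_forward_v1)
    out = fn(query, key_cache, value_cache, block_tables, query_start_loc,
             seq_lens, context_lens, max_input_length, max_context_length)
    # out: [num_tokens, num_heads, head_dim], fp16
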