From 6beec91a68d7afa49f110962c22e4f6194a02cb2 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 24 Feb 2025 04:59:58 +0000 Subject: [PATCH] hip fixes --- src/ops/inc_multihead_self_attention.cpp | 1059 ++++++++++++----- src/ops/inc_multihead_self_attention.cu | 1 - src/ops/spec_inc_multihead_self_attention.cpp | 576 +++++---- src/ops/tree_inc_multihead_self_attention.cpp | 266 ++--- 4 files changed, 1162 insertions(+), 740 deletions(-) diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 6800a32ff..6acdce039 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -53,6 +53,21 @@ __device__ __forceinline__ T #endif } +std::string get_fwd_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, + int shard_id) { + std::string op_name_without_uid = + IncMultiHeadSelfAttention::get_op_name_without_uid(m); + fs::path dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id); + if (m->layer_guid.model_id > 0) { + assert(false && "Model ID > 0 not supported yet"); + } + std::string layername = "layers." + + std::to_string(m->layer_guid.transformer_layer_id) + + "." + op_name_without_uid; + dst_filepath /= layername; + return dst_filepath.string(); +} + template __global__ void store_kv_cache(DT const *devQKVProjArray, DT *kCache_ptr, @@ -60,42 +75,56 @@ __global__ void store_kv_cache(DT const *devQKVProjArray, BatchConfig::PerTokenInfo const *tokenInfos, int num_tokens, int max_seq_len, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - int offset = i % hidden_size; + int head_dim, + int num_q_heads, + int num_kv_heads) { + CUDA_KERNEL_LOOP(i, num_tokens * head_dim * num_kv_heads) { + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + // kCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + // vCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + + // i is iterating over one set of key/val projections from the input + int token_idx = i / (head_dim * num_kv_heads); + int head_idx = (i / head_dim) % num_kv_heads; + int offset = i % head_dim; + + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + int key_src_idx = token_idx * head_dim * tot_num_heads + + head_dim * num_q_heads + head_dim * head_idx + offset; + int val_src_idx = key_src_idx + head_dim * num_kv_heads; - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; - - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + int dst_idx = req_id * (head_dim * num_kv_heads * max_seq_len) + + tok_id * head_dim * num_kv_heads + head_idx * head_dim + + offset; - // key cache - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; + kCache_ptr[dst_idx] = devQKVProjArray[key_src_idx]; + vCache_ptr[dst_idx] = devQKVProjArray[val_src_idx]; } } template __global__ void store_query_cache(DT const *devQKVProjArray, DT *qCache_ptr, - int num_tokens, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int num_tokens_in_batch, + int first_token_offset_in_batch, + int first_token_depth_in_request, + int head_dim, + int num_q_heads, + int num_kv_heads) { + CUDA_KERNEL_LOOP(i, num_tokens_in_batch * head_dim * num_q_heads) { + int 
hidden_size = head_dim * num_q_heads; + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + int token_idx = i / hidden_size; int offset = i % hidden_size; + int src_idx = + (first_token_offset_in_batch + token_idx) * (head_dim * tot_num_heads) + + offset; - size_t val_idx = token_idx * QKV_WEIGHT_NUM * hidden_size + offset; - - DT qVal = devQKVProjArray[val_idx]; - - // query cache - qCache_ptr[i] = qVal; + qCache_ptr[first_token_depth_in_request * hidden_size + i] = + devQKVProjArray[src_idx]; } } @@ -126,9 +155,178 @@ bool is_decoding_request(BatchConfig const *bc, int request_id) { !bc->requestsInfo[request_id].prompt_phase; } +template +void run_batched_matmul(IncMultiHeadSelfAttentionMeta const *meta, + hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + void const *alpha, + const DT *A, + hipblasDatatype_t Atype, + int lda, + long long int strideA, + const DT *B, + hipblasDatatype_t Btype, + int ldb, + long long int strideB, + void const *beta, + DT *C, + hipblasDatatype_t Ctype, + int ldc, + long long int strideC, + int batchCount, + hipblasDatatype_t computeType, + hipblasGemmAlgo_t algo, + hipStream_t stream, + int batch_ratio_a, + int batch_ratio_b, + int batch_ratio_c, + bool bwd) { + if (batch_ratio_a == 1 && batch_ratio_b == 1 && batch_ratio_c == 1) { + checkCUDA(hipblasGemmStridedBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + strideA, + B, + Btype, + ldb, + strideB, + beta, + C, + Ctype, + ldc, + strideC, + batchCount, + computeType, + algo)); + } else { + const DT **h_A_array = new const DT *[batchCount]; + const DT **h_B_array = new const DT *[batchCount]; + DT **h_C_array = new DT *[batchCount]; + for (int batch = 0; batch < batchCount; batch++) { + h_A_array[batch] = A + (batch / batch_ratio_a) * strideA; + h_B_array[batch] = B + (batch / batch_ratio_b) * strideB; + h_C_array[batch] = C + (batch / batch_ratio_c) * strideC; + } + assert(sizeof(DT *) == sizeof(void *)); + if (!bwd) { + // Copy pointer arrays to device + checkCUDA(hipMemcpyAsync(meta->d_A_array, + h_A_array, + batchCount * sizeof(DT *), + hipMemcpyHostToDevice, + stream)); + checkCUDA(hipMemcpyAsync(meta->d_B_array, + h_B_array, + batchCount * sizeof(DT *), + hipMemcpyHostToDevice, + stream)); + checkCUDA(hipMemcpyAsync(meta->d_C_array, + h_C_array, + batchCount * sizeof(DT *), + hipMemcpyHostToDevice, + stream)); + + checkCUDA(hipblasGemmBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + (void const **)meta->d_A_array, + Atype, + lda, + (void const **)meta->d_B_array, + Btype, + ldb, + beta, + meta->d_C_array, + Ctype, + ldc, + batchCount, + computeType, + algo)); + } else { + checkCUDA(hipMemcpyAsync(meta->d_A_array2, + h_A_array, + batchCount * sizeof(DT *), + hipMemcpyHostToDevice, + stream)); + checkCUDA(hipMemcpyAsync(meta->d_B_array2, + h_B_array, + batchCount * sizeof(DT *), + hipMemcpyHostToDevice, + stream)); + checkCUDA(hipMemcpyAsync(meta->d_C_array2, + h_C_array, + batchCount * sizeof(DT *), + hipMemcpyHostToDevice, + stream)); + + checkCUDA(hipblasGemmBatchedEx(handle, + transa, + transb, + m, + n, + k, + alpha, + (void const **)meta->d_A_array2, + Atype, + lda, + (void const **)meta->d_B_array2, + Btype, + ldb, + beta, + meta->d_C_array2, + Ctype, + ldc, + batchCount, + computeType, + algo)); + } + } +} + +template +__global__ void store_softmax_activation(DT const *qk_prods_softmax, + DT *softmax_activation_buffer, + int num_new_tokens, + int total_tokens, + int 
max_finetuning_seq_len, + int num_q_heads) { + CUDA_KERNEL_LOOP(i, num_new_tokens * total_tokens * num_q_heads) { + // qk_prods_softmax: [num_new_tokens, total_tokens, num_q_heads] + // softmax activation buffer: [MAX_FINETUNING_LENGTH(num_new_tokens), + // MAX_FINETUNING_LENGTH(total_tokens), num_q_heads] + int tokens_previous_steps = total_tokens - num_new_tokens; + int new_tokens_idx = i % num_new_tokens; + int total_tokens_idx = (i / num_new_tokens) % total_tokens; + int head_idx = i / (num_new_tokens * total_tokens); + int src_idx = head_idx * num_new_tokens * total_tokens + + total_tokens_idx * num_new_tokens + new_tokens_idx; + int dst_idx = head_idx * max_finetuning_seq_len * max_finetuning_seq_len + + total_tokens_idx * max_finetuning_seq_len + + (tokens_previous_steps + new_tokens_idx); + + softmax_activation_buffer[dst_idx] = qk_prods_softmax[src_idx]; + } +} + template void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, + DT *attn_heads, int shard_id, hipStream_t stream) { checkCUDA(hipblasSetStream(m->handle.blas, stream)); @@ -138,32 +336,24 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, assert(data_type_size(m->output_type[0]) == sizeof(DT)); hipblasDatatype_t compute_type = cublas_data_type; - int num_tokens = bc->num_active_tokens(); - int tokens_previous_requests = 0; - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i] || is_decoding_request(bc, i) || - is_finetuning_bwd_request(bc, i)) { + assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); + int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + int num_processed_prompt_tokens = 0; + for (int req_idx = 0; req_idx < bc->max_requests_per_batch(); req_idx++) { + if (bc->request_completed[req_idx] || is_decoding_request(bc, req_idx) || + is_finetuning_bwd_request(bc, req_idx)) { continue; } - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; + int num_new_tokens = bc->requestsInfo[req_idx].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[req_idx].first_token_depth_in_request + + bc->requestsInfo[req_idx].num_tokens_in_batch; if (num_new_tokens <= 0) { continue; } // Copy query to m->query_activation_buffer if we need to compute // PEFT backward - if (bc->requestsInfo[i].finetuning_request && - !bc->requestsInfo[i].finetuning_backward_phase) { + if (bc->requestsInfo[req_idx].finetuning_request && + !bc->requestsInfo[req_idx].finetuning_backward_phase) { // int max_peft_tokens = bc->requestsInfo[i].max_length; int max_peft_tokens = BatchConfig::max_sequence_length(); size_t activation_size_needed = @@ -181,7 +371,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, std::cout << "sizeof(DT)" << sizeof(DT) << std::endl; } assert(activation_size_needed == m->allocated_peft_buffer_size1); - int parallelism = m->hidden_size * num_tokens; + int parallelism = m->qProjSize * m->num_q_heads * num_new_tokens; + int tokens_previous_steps = total_tokens - num_new_tokens; 
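// A brief note on the launch below: store_query_cache copies the Q projections of this
// finetuning request into m->query_activation_buffer so the PEFT backward pass can reuse
// them. Illustrative index math (assuming head_dim == qProjSize, t = token in batch,
// h = query head, d = element within the head):
//   src = (first_token_offset_in_batch + t) * head_dim * (num_q_heads + 2 * num_kv_heads) + h * head_dim + d
//   dst = (first_token_depth_in_request + t) * head_dim * num_q_heads + h * head_dim + d
// i.e. the K/V head blocks of devQKVProjArray are skipped, and chunks that arrive over
// several decoding steps land at the request's current depth offset in the buffer.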
hipLaunchKernelGGL(HIP_KERNEL_NAME(store_query_cache), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), @@ -189,14 +380,18 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, stream, static_cast
<DT *>(m->devQKVProjArray), static_cast<DT *>
(m->query_activation_buffer), - num_tokens, - m->hidden_size); + num_new_tokens, + num_processed_prompt_tokens, + tokens_previous_steps, + m->qProjSize, + m->num_q_heads, + m->num_kv_heads); } // Step 1: compute query-key product QK.T/sqrt(d_k) { - // Scale by sqrt(d_k) as per the original attention paper DT alpha = 1.0f, beta = 0.0f; if (*m->qk_prod_scaling) { + // Scale by sqrt(d_k) as per the original attention paper alpha = static_cast
(1.0f / sqrt(m->kProjSize)); } // after transpositions @@ -204,89 +399,118 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, int n = total_tokens; int k = m->qProjSize; // before transpositions - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; + int lda = m->qProjSize * tot_num_heads; + int ldb = m->kProjSize * m->num_kv_heads; + int ldc = num_new_tokens; // N.B. strides are applied before transpose operations - int strideA = q_block_size; - int strideB = kt_block_size; + int strideA = m->qProjSize; + int strideB = m->kProjSize; int strideC = num_new_tokens * total_tokens; // matrix A: devQKVProjArray - // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens] + // matrix A's layout: [qProjSize, tot_num_heads, num_new_tokens] // To get query projection, skip over Q entries from previous requests DT const *A = static_cast
(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; + bc->requestsInfo[req_idx].first_token_offset_in_batch * + m->qProjSize * (m->num_q_heads + 2 * m->num_kv_heads); // matrix B: key cache - // matrix B's layout: [kProjSize * num_heads, total_tokens] + // matrix B's layout: [kProjSize, num_kv_heads, total_tokens] // To get B, skip over K entries from previous requests (all heads + // padding) - DT const *B = static_cast
<DT *>(m->keyCache) + i * kt_req_block_size; - // matrix C: qk_prods - // matrix C's layout: [num_new_tokens, total_tokens, num_heads] - // To get C, skip over QK.T products from previous requests + DT const *B = static_cast<DT *>
(m->keyCache) + + req_idx * (m->kProjSize * m->num_kv_heads * + BatchConfig::max_sequence_length()); + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] DT *C = static_cast<DT *>
(m->qk_prods); - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); + run_batched_matmul
(m, + m->handle.blas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT, + stream, + 1, + m->num_q_heads / m->num_kv_heads, + 1); + if (m->inference_debugging) { + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".qk_prods"; + save_tensor(static_cast
<DT *>(m->qk_prods), + num_new_tokens * total_tokens * m->num_q_heads, + fpath.c_str()); + } } // Step 2: Add alibi position bias to qk production // matrix C: qk_prods // matrix C's layout: [num_new_tokens, total_tokens, num_heads] // To get C, skip over QK.T products from previous requests - DT *C = static_cast<DT *>
(m->qk_prods); - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); + { + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
(m->qk_prods); + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); + } } // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods // with -inf to force causal attention. - assert(num_new_tokens <= total_tokens); - size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - hipLaunchKernelGGL(HIP_KERNEL_NAME(fill_entries_above_diagonal), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - entries_above_diagonal, - static_cast
<DT>(-INFINITY)); + { + assert(num_new_tokens <= total_tokens); + size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; + if (entries_above_diagonal > 0) { + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast<DT *>
(m->qk_prods); + size_t parallelism = m->num_q_heads * entries_above_diagonal; + hipLaunchKernelGGL(HIP_KERNEL_NAME(fill_entries_above_diagonal), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + entries_above_diagonal, + static_cast
<DT>(-INFINITY)); + } + if (m->inference_debugging) { + std::string fpath = + get_fwd_dbg_folder(m, shard_id) + ".qk_prods.masked"; + save_tensor(static_cast<DT *>
(m->qk_prods), + num_new_tokens * total_tokens * m->num_q_heads, + fpath.c_str()); + } } // Step 4: Compute Softmax(QK.T/sqrt(d_k)) @@ -307,6 +531,11 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, checkCUDNN(miopenSet4dTensorDescriptor( m->qk_tensor, cudnn_data_type, n_param, c_param, h_param, w_param)); float softmax_alpha = 1.0f, softmax_beta = 0.0f; + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
<DT *>(m->qk_prods); + // matrix C_softmax: qk_prods_softmax (current req only) + // matrix C_softmax's layout: [num_new_tokens, total_tokens, num_q_heads] DT *C_softmax = static_cast<DT *>
(m->qk_prods_softmax); // The softmax operation below is executed according to the // MIOPEN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The @@ -321,21 +550,34 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, C_softmax, MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_CHANNEL)); - } - // Copy C_softmax to m->softmax_activation_buffer if we need to compute - // PEFT backward - if (bc->requestsInfo[i].finetuning_request) { - int max_peft_tokens = BatchConfig::max_sequence_length(); - DT *C_softmax = static_cast
(m->qk_prods_softmax); - size_t activation_size_needed = - sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; - assert(activation_size_needed == m->allocated_peft_buffer_size2); - checkCUDA(hipMemcpyAsync(m->softmax_activation_buffer, - C_softmax, - sizeof(DT) * total_tokens * num_new_tokens * - m->num_q_heads, - hipMemcpyDeviceToDevice, - stream)); + // Copy C_softmax to m->softmax_activation_buffer if we need to compute + // PEFT backward + if (bc->requestsInfo[req_idx].finetuning_request) { + int max_peft_tokens = BatchConfig::max_sequence_length(); + int max_dataset_entry_size = bc->requestsInfo[req_idx].max_length; + size_t activation_size_needed = + sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; + assert(activation_size_needed == m->allocated_peft_buffer_size2); + int parallelism = m->num_q_heads * total_tokens * num_new_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(store_softmax_activation), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + static_cast
<DT *>(m->qk_prods_softmax), + static_cast<DT *>
(m->softmax_activation_buffer), + num_new_tokens, + total_tokens, + max_dataset_entry_size, + m->num_q_heads); + } + if (m->inference_debugging) { + std::string fpath = + get_fwd_dbg_folder(m, shard_id) + ".qk_prods_softmax"; + save_tensor(static_cast<DT *>
(m->qk_prods_softmax), + num_new_tokens * total_tokens * m->num_q_heads, + fpath.c_str()); + } } // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ @@ -347,62 +589,80 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, int n = num_new_tokens; int k = total_tokens; // before transpositions - int lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + int lda = m_ * m->num_kv_heads; + int ldb = n; + int ldc = m_ * m->num_q_heads; // N.B. strides are applied before transpose operations - int strideA = vt_block_size; + int strideA = m->vProjSize; int strideB = num_new_tokens * total_tokens; int strideC = m->vProjSize; // matrix A: value cache - // matrix A's layout: [vProjSize, num_heads, total_tokens] + // matrix A's layout: [vProjSize, num_kv_heads, total_tokens] // To get A, skip over V.T entries from previous requests (all heads + // padding) - DT *A = static_cast
<DT *>(m->valueCache) + i * vt_req_block_size; - // matrix B: qk_prods_softmax - // matrix B's layout: [num_new_tokens, total_tokens, num_heads] + DT *A = static_cast<DT *>
(m->valueCache) + + req_idx * (m->vProjSize * m->num_kv_heads * + BatchConfig::max_sequence_length()); + // matrix B: qk_prods_softmax (current req only) + // matrix B's layout: [num_new_tokens, total_tokens, num_q_heads] // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous // requests (all heads) DT *B = static_cast<DT *>
(m->qk_prods_softmax); // matrix C: attn heads - // matrix C's layout: [vProjSize, num_heads, num_new_tokens] + // matrix C's layout: [vProjSize, num_q_heads, num_new_tokens] // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous // requests // store the result attn heads, also skip the generation tokens - DT *C = static_cast<DT *>
(m->attn_heads) + - (bc->requestsInfo[i].first_token_offset_in_batch) * + DT *C = static_cast<DT *>
(attn_heads) + + (bc->requestsInfo[req_idx].first_token_offset_in_batch) * m->num_q_heads * m->vProjSize; - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); + run_batched_matmul
(m, + m->handle.blas, + HIPBLAS_OP_N, + HIPBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT, + stream, + m->num_q_heads / m->num_kv_heads, + 1, + 1); + if (m->inference_debugging) { + std::string fpath = + get_fwd_dbg_folder(m, shard_id) + ".qk_prods_softmax"; + save_tensor(static_cast
(attn_heads), + num_new_tokens * m->num_q_heads * m->vProjSize, + fpath.c_str()); + } } - tokens_previous_requests += num_new_tokens; + num_processed_prompt_tokens += num_new_tokens; } - if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { + if (num_processed_prompt_tokens != + (bc->num_active_tokens() - bc->num_generation_tokens)) { bc->print(); - printf("tokens_previous_requests: %i\n", tokens_previous_requests); - printf("num_tokens: %i\n", num_tokens); + printf("num_processed_prompt_tokens: %i\n", num_processed_prompt_tokens); + printf("num_tokens: %i\n", bc->num_active_tokens()); printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); } - assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); + assert(num_processed_prompt_tokens == + (bc->num_active_tokens() - bc->num_generation_tokens)); } // gridDim = num_heads @@ -423,9 +683,12 @@ __global__ void compute_attention_kernel_generation_kernel( float const scale, int max_seq_length, int per_head_size, - int hidden_size, + int num_q_heads, + int num_kv_heads, BatchConfig::PerRequestInfo *request_infos) { + int total_num_heads = num_q_heads + 2 * num_kv_heads; + // q, k using Q_vec = typename VEC_K::Type; using K_vec = typename VEC_K::Type; @@ -449,6 +712,7 @@ __global__ void compute_attention_kernel_generation_kernel( int const tidx = threadIdx.x; // head id int const head_idx = blockIdx.x; + int const kv_head_idx = head_idx / (num_q_heads / num_kv_heads); // request idx int const request_idx = blockIdx.y; @@ -472,7 +736,7 @@ __global__ void compute_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + request_idx * per_head_size * total_num_heads + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; // DT const *q_ptr = @@ -506,8 +770,11 @@ __global__ void compute_attention_kernel_generation_kernel( // // The number of keys per warp. constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - DT const *k_cache_batch = - key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + DT const *k_cache_batch = key_cache + + batch_config_request_id * + (per_head_size * num_kv_heads) * + max_seq_length + + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -520,9 +787,9 @@ __global__ void compute_attention_kernel_generation_kernel( for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; if (ti < tlength) { - k[ii] = *reinterpret_cast(k_cache_batch + - ti_circ * hidden_size + - head_idx * per_head_size + jj); + k[ii] = *reinterpret_cast( + k_cache_batch + ti_circ * (per_head_size * num_kv_heads) + + kv_head_idx * per_head_size + jj); } // Compute dot product. // This includes a reduction across the threads in the same thread group. @@ -609,8 +876,10 @@ __global__ void compute_attention_kernel_generation_kernel( zero(out); // The base pointer for the value in the cache buffer. 
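// With grouped-query attention the value cache holds only num_kv_heads heads per cached
// token, so the per-request stride used below is max_seq_length * per_head_size * num_kv_heads,
// and every query head reads the value head of its group
// (kv_head_idx = head_idx / (num_q_heads / num_kv_heads)). Illustrative element address:
//   v = value_cache[(req * max_seq_length + t) * per_head_size * num_kv_heads
//                   + kv_head_idx * per_head_size + d]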
- DT const *v_cache_batch = - value_cache + batch_config_request_id * max_seq_length * hidden_size + vi; + DT const *v_cache_batch = value_cache + + batch_config_request_id * max_seq_length * + (per_head_size * num_kv_heads) + + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { @@ -618,7 +887,8 @@ __global__ void compute_attention_kernel_generation_kernel( int const ti_circ = ti % max_seq_length; V_vec v = *reinterpret_cast( - v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + v_cache_batch + ti_circ * (per_head_size * num_kv_heads) + + kv_head_idx * per_head_size); float logit = qk_smem[ti - first_step]; out = FlexFlow::fma(logit, cast_to_float(v), out); } @@ -656,7 +926,8 @@ __global__ void compute_attention_kernel_generation_kernel( // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { convert_from_float( - *reinterpret_cast(output_ptr + request_idx * hidden_size + + *reinterpret_cast(output_ptr + + request_idx * (per_head_size * num_q_heads) + head_idx * per_head_size + vi), out); } @@ -692,43 +963,48 @@ __global__ void scaling_query_kernel(DT *input_ptr, int qProjSize, int num_tokens, int num_q_heads, - float scaling_factor, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - input_ptr[i % hidden_size + token_idx * hidden_size * QKV_WEIGHT_NUM] *= - scaling_factor; + int num_kv_heads, + float scaling_factor) { + CUDA_KERNEL_LOOP(i, (qProjSize * num_q_heads) * num_tokens) { + int token_idx = i / (qProjSize * num_q_heads); + int offset = i % (qProjSize * num_q_heads); + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + int idx = token_idx * qProjSize * tot_num_heads + offset; + input_ptr[idx] *= scaling_factor; } } template __global__ void - apply_rotary_embedding_hf(DT *input_ptr, - hipFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - float rope_theta, - bool llama3_rope, - float factor, - float low_freq_factor, - float high_freq_factor, - int original_max_position_embeddings, - int qProjSize, - int kProjSize, - int num_tokens, - size_t q_array_size, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + apply_rotary_embedding_fwd(DT *input_ptr, + hipFloatComplex *complex_input, + BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, + int proj_size, + int num_tokens, + int num_q_heads, + int num_kv_heads) { + CUDA_KERNEL_LOOP(i, num_tokens * num_q_heads * proj_size) { + size_t q_array_size = proj_size * num_q_heads * num_tokens; + int hidden_size = num_q_heads * proj_size; + int total_num_heads = num_q_heads + 2 * num_kv_heads; // create complex number bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; int real_i = q_tensor ? i : i - q_array_size / 2; - int token_idx = real_i / (hidden_size / 2); int idx = real_i % (proj_size / 2); int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); + if (!q_tensor) { + head_idx /= (num_q_heads / num_kv_heads); + } int real_part_index = idx + head_idx * proj_size + - token_idx * hidden_size * QKV_WEIGHT_NUM + + token_idx * proj_size * total_num_heads + hidden_size * (q_tensor ? 
0 : 1); int complex_part_index = real_part_index + (proj_size / 2); @@ -853,11 +1129,10 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, checkCUDA(hipblasSetStream(m->handle.blas, stream)); checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - assert(m->qSize == m->vSize && m->qSize == m->kSize); + assert(m->qProjSize == m->kProjSize && m->qProjSize == m->vProjSize); int num_tokens = bc->num_active_tokens(); int parallelism = m->kProjSize * num_tokens * m->num_q_heads; - size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads; if (m->scaling_query) { hipLaunchKernelGGL(HIP_KERNEL_NAME(scaling_query_kernel), @@ -869,8 +1144,8 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, m->qProjSize, num_tokens, m->num_q_heads, - m->scaling_factor, - m->hidden_size); + m->num_kv_heads, + m->scaling_factor); } // Step 3: apply rotary embedding if needed @@ -878,7 +1153,7 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, /*q&k*/ parallelism = num_tokens * m->hidden_size; hipLaunchKernelGGL( - HIP_KERNEL_NAME(apply_rotary_embedding_hf), + HIP_KERNEL_NAME(apply_rotary_embedding_fwd), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, @@ -893,10 +1168,9 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, m->rotary_embedding_meta->high_freq_factor, m->rotary_embedding_meta->original_max_position_embeddings, m->qProjSize, - m->kProjSize, num_tokens, - q_array_size, - m->hidden_size); + m->num_q_heads, + m->num_kv_heads); } } @@ -905,8 +1179,13 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, hipStream_t stream) { int num_tokens = bc->num_active_tokens(); + int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + assert(m->hidden_size % m->num_q_heads == 0); + int head_dim = m->hidden_size / m->num_q_heads; + assert(head_dim == m->qProjSize); if (num_tokens > 0) { - int parallelism = m->hidden_size * num_tokens; + int parallelism = head_dim * tot_num_heads * num_tokens; + // devQKVProj has shape [qProjSize, tot_num_heads, num_new_tokens] hipLaunchKernelGGL(HIP_KERNEL_NAME(store_kv_cache), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), @@ -918,7 +1197,9 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, m->token_infos, num_tokens, BatchConfig::max_sequence_length(), - m->hidden_size); + head_dim, + m->num_q_heads, + m->num_kv_heads); } } @@ -942,7 +1223,8 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, scale, \ BatchConfig::max_sequence_length(), \ m->qProjSize, \ - m->hidden_size, \ + m->num_q_heads, \ + m->num_kv_heads, \ m->request_infos) template @@ -967,21 +1249,6 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, } } -std::string get_fwd_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, - int shard_id) { - std::string op_name_without_uid = - IncMultiHeadSelfAttention::get_op_name_without_uid(m); - fs::path dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id); - if (m->layer_guid.model_id > 0) { - assert(false && "Model ID > 0 not supported yet"); - } - std::string layername = "layers." + - std::to_string(m->layer_guid.transformer_layer_id) + - "." 
+ op_name_without_uid; - dst_filepath /= layername; - return dst_filepath.string(); -} - template void inference_kernel(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, @@ -991,9 +1258,10 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, hipStream_t stream) { // phase 0: copy calculated qkv into devQKVProjArray - // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + // [qProjSize, tot_num_heads, num_new_tokens] + assert(m->qProjSize == m->kProjSize && m->qProjSize == m->vProjSize); + size_t tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + size_t qkv_proj_size = m->qProjSize * tot_num_heads * bc->num_active_tokens(); checkCUDA(hipMemcpyAsync(m->devQKVProjArray, qkv_ptr, @@ -1001,20 +1269,46 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, hipMemcpyDeviceToDevice, stream)); + if (m->inference_debugging) { + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".devQKVProjArray"; + save_tensor(static_cast
<DT *>(m->devQKVProjArray), + qkv_proj_size, + fpath.c_str()); + } + // phase 1: Implement kernel to apply rotary embedding and scaling apply_scaling_and_rotary( m, bc, shard_id, static_cast<DT *>
(m->devQKVProjArray), stream); + + if (m->inference_debugging) { + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".post_rope"; + save_tensor(static_cast<DT *>
(m->devQKVProjArray), + qkv_proj_size, + fpath.c_str()); + } + update_kv_cache_kernel<DT>
(m, bc, stream); + if (m->inference_debugging) { + size_t key_cache_size = m->kProjSize * m->num_kv_heads * + BatchConfig::max_sequence_length() * + BatchConfig::max_requests_per_batch(); + std::string fpath = get_fwd_dbg_folder(m, shard_id) + ".key_cache"; + save_tensor( + static_cast<DT *>
(m->keyCache), key_cache_size, fpath.c_str()); + fpath = get_fwd_dbg_folder(m, shard_id) + ".value_cache"; + save_tensor( + static_cast<DT *>
(m->valueCache), key_cache_size, fpath.c_str()); + } + if (bc->num_generation_tokens > 0) { // phase 3: Compute attention score for generation tokens - compute_attention_kernel_generation<DT>
( - m, bc, static_cast<DT *>
(m->attn_heads), stream); + compute_attention_kernel_generation<DT>
(m, bc, output_ptr, stream); } if (bc->num_tokens > bc->num_generation_tokens) { // phase 4: Compute attention score for prompt tokens; - compute_attention_kernel_prompt<DT>
(m, bc, shard_id, stream); + compute_attention_kernel_prompt<DT>
(m, bc, output_ptr, shard_id, stream); } if (bc->num_finetuning_fwd_tokens() > 0) { @@ -1031,13 +1325,6 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, bc->tokensInfo[tokens_previous_requests + j]; } } - - int num_tokens = bc->num_active_tokens(); - checkCUDA(hipMemcpyAsync(output_ptr, - m->attn_heads, - m->oProjSize * num_tokens * sizeof(DT), - hipMemcpyDeviceToDevice, - stream)); } std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, @@ -1144,12 +1431,6 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, assert(num_tokens == num_total_tokens); assert(num_total_tokens == bc->requestsInfo[i].max_length); assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); // Step 1: copy gradient before final projection into workspace { @@ -1174,13 +1455,16 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, { float alpha = 1.0f, beta = 0.0f; // matrix A: qk_prods_softmax - // matrix A's layout: [num_new_tokens, total_tokens, num_heads] + // matrix A's layout: [num_new_tokens, total_tokens, num_q_heads] DT const *A = static_cast
<DT *>(m->qk_prods_softmax); // matrix B: attn_heads gradients - // matrix B's layout: [vProjSize * num_heads, num_new_tokens] + // matrix B's layout: [vProjSize * num_q_heads, num_new_tokens] DT const *B = static_cast<DT *>
(m->handle.workSpace); // matrix C: gradients for value (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + // matrix C's layout: [num_tokens, qProjsize * num_q_heads, 3] + // note that we first need to compute the gradients wrt each q_heads, then + // we can sum the gradients corresponding to each group of q_heads to obtain + // the gradients wrt each value head DT *C = static_cast
(m->devQKVProjArray) + 2 * num_tokens * (m->qProjSize * m->num_q_heads); // skip over regions reserved @@ -1234,13 +1518,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, { float alpha = 1.0f, beta = 0.0f; // matrix A: attn_heads gradients - // matrix A's layout: [vProjSize * num_heads, num_new_tokens] + // matrix A's layout: [vProjSize * num_q_heads, num_new_tokens] DT const *A = static_cast
<DT *>(m->handle.workSpace); // matrix B: value cache - // matrix B's layout: [vProjSize * num_heads, max_num_tokens, num_req] - DT const *B = static_cast<DT *>
(m->valueCache) + i * vt_req_block_size; + // matrix B's layout: [vProjSize * num_kv_heads, max_num_tokens, num_req] + DT const *B = + static_cast<DT *>
(m->valueCache) + + i * m->vProjSize * m->num_kv_heads * BatchConfig::max_sequence_length(); // matrix C: qk_prods_softmax gradients - // matrix C's layout: [num_new_tokens, total_tokens, num_heads] + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] DT *C = static_cast<DT *>
(m->qk_prods_softmax); // after transposition & striding int m_ = num_tokens; // num_new_tokens @@ -1248,43 +1534,51 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, int k_ = m->vProjSize; // before transposition and striding int lda = m->vProjSize * m->num_q_heads; - int ldb = m->vProjSize * m->num_q_heads; + int ldb = m->vProjSize * m->num_kv_heads; int ldc = num_tokens; // num_new_tokens int strideA = m->vProjSize; int strideB = m->vProjSize; int strideC = num_tokens * num_tokens; // num_new_tokens * total_tokens - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); + run_batched_matmul
(m, + m->handle.blas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + m_, + n_, + k_, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT, + stream, + 1, + m->num_q_heads / m->num_kv_heads, + 1, + true); if (m->inference_debugging) { std::string filename = get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad"; save_tensor( C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); std::string filename2 = get_peft_dbg_folder(m, shard_id) + ".vcache"; - save_tensor( - B, m->vProjSize * m->num_q_heads * num_tokens, filename2.c_str()); + save_tensor(B, + m->vProjSize * m->num_kv_heads * + BatchConfig::max_sequence_length(), + filename2.c_str()); } } // Step 4: softmax backpropagation @@ -1314,6 +1608,11 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad_in"; save_tensor( C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + filename = + get_peft_dbg_folder(m, shard_id) + ".softmax_activation_buffer"; + save_tensor(static_cast
(m->softmax_activation_buffer), + num_tokens * num_tokens * m->num_q_heads, + filename.c_str()); } // TODO: fill all elements above diagonal to force causal attention @@ -1347,13 +1646,16 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, alpha = 1.0f / sqrt(m->kProjSize); } // matrix A: gradients w.r.t. qk_prods - // matrix A's layout: [num_new_tokens, num_tokens, num_heads] + // matrix A's layout: [num_new_tokens, num_tokens, num_q_heads] DT const *A = static_cast
<DT *>(m->qk_prods); // matrix B: query activation (in query_activation_buffer) - // matrix B's layout: [m->qProjSize * num_heads, num_new_tokens] + // matrix B's layout: [m->qProjSize * num_q_heads, num_new_tokens] DT const *B = static_cast<DT *>
(m->query_activation_buffer); // matrix C: gradients for key (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + // matrix C's layout: [num_tokens, qProjsize * num_q_heads, 3] + // note that we first need to compute the gradients wrt each q_heads, then + // we can sum the gradients corresponding to each group of q_heads to obtain + // the gradients wrt each key head DT *C = static_cast
(m->devQKVProjArray) + num_tokens * (m->qProjSize * @@ -1410,13 +1712,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, alpha = 1.0f / sqrt(m->kProjSize); } // matrix A: gradients w.r.t. qk_prods - // matrix A's layout: [num_new_tokens, num_tokens, num_heads] + // matrix A's layout: [num_new_tokens, num_tokens, num_q_heads] DT const *A = static_cast
<DT *>(m->qk_prods); // matrix B: key cache - // matrix B's layout: [vProjSize * num_heads, max_num_tokens, num_req] - DT const *B = static_cast<DT *>
(m->keyCache) + i * kt_req_block_size; + // matrix B's layout: [vProjSize * num_kv_heads, max_num_tokens, num_req] + DT const *B = + static_cast<DT *>
(m->keyCache) + + i * m->kProjSize * m->num_kv_heads * BatchConfig::max_sequence_length(); // matrix C: gradients for query (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + // matrix C's layout: [num_tokens, qProjsize * num_q_heads, 3] DT *C = static_cast<DT *>
(m->devQKVProjArray); // after transposition & striding int m_ = num_tokens; // num_new_tokens @@ -1424,34 +1728,40 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, int k_ = num_tokens; // before transposition and striding int lda = num_tokens; // num_new_tokens - int ldb = m->qProjSize * m->num_q_heads; + int ldb = m->qProjSize * m->num_kv_heads; int ldc = num_tokens; int strideA = num_tokens * num_tokens; int strideB = m->qProjSize; int strideC = num_tokens * m->qProjSize; - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); + run_batched_matmul
(m, + m->handle.blas, + HIPBLAS_OP_N, + HIPBLAS_OP_T, + m_, + n_, + k_, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT, + stream, + 1, + m->num_q_heads / m->num_kv_heads, + 1, + true); if (m->inference_debugging) { std::string filename = get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray_pre"; @@ -1529,7 +1839,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; // int m_ = m->qSize; int n_ = num_tokens; - int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); + int k_ = m->qProjSize * (m->num_q_heads + 2 * m->num_kv_heads); // The original version uses existing result and attention's projection to // do further calculation in a way different than the usual dense layer, @@ -1731,21 +2041,31 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( position_bias = (bool *)calloc(1, sizeof(bool)); *position_bias = _position_bias; + assert(num_q_heads % num_kv_heads == 0 && + "num_q_heads must be divisible by num_kv_heads"); + if (num_q_heads > num_kv_heads) { + // grouped query attention + assert(attn->data_type == DT_FLOAT || + attn->data_type == DT_HALF && "Unsupported data type"); + gqa_ptr_array_size = num_q_heads * sizeof(void *); + } + // allocate memory for the seqArray and reserve space { int max_tokens_per_batch = infer_mode == TREE_VERIFY_MODE ? BatchConfig::max_verify_tokens_per_batch() : BatchConfig::max_tokens_per_batch(); - size_t qkv_max_proj_size = max_tokens_per_batch * (qProjSize * num_q_heads + - kProjSize * num_q_heads + - vProjSize * num_q_heads); - size_t key_cache_size = 0, value_cache_size = 0; + size_t qkv_max_proj_size = + max_tokens_per_batch * + (qProjSize * num_q_heads + kProjSize * num_kv_heads + + vProjSize * num_kv_heads); + size_t query_tmp_size = 0, key_cache_size = 0, value_cache_size = 0; switch (infer_mode) { case INC_DECODING_MODE: { - key_cache_size = num_q_heads * kProjSize * + key_cache_size = num_kv_heads * kProjSize * BatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length(); - value_cache_size = num_q_heads * vProjSize * + value_cache_size = num_kv_heads * vProjSize * BatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length(); break; @@ -1753,11 +2073,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( case BEAM_SEARCH_MODE: case TREE_VERIFY_MODE: { // a K-ary tree max node is (k^n - 1) / 2 - key_cache_size = num_q_heads * kProjSize * + key_cache_size = num_kv_heads * kProjSize * BeamSearchBatchConfig::max_requests_per_batch() * (BatchConfig::max_sequence_length() + BatchConfig::max_spec_tree_token_num()); - value_cache_size = num_q_heads * vProjSize * + value_cache_size = num_kv_heads * vProjSize * BeamSearchBatchConfig::max_requests_per_batch() * (BatchConfig::max_sequence_length() + BatchConfig::max_spec_tree_token_num()); @@ -1792,30 +2112,33 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( peft_token_infos = nullptr; peft_token_infos_size = 0; } - size_t totalSize = - (qkv_max_proj_size + key_cache_size + value_cache_size + - 2 * qk_prod_size + attn_heads_size) * - size_of_dt + - complex_size * sizeof(hipFloatComplex); // more components will - // be added here later + size_t totalSize = (qkv_max_proj_size + query_tmp_size + key_cache_size + + value_cache_size + 2 * qk_prod_size + attn_heads_size) * + size_of_dt + + 
complex_size * sizeof(hipFloatComplex) + + 3 * gqa_ptr_array_size; if (enable_peft_finetuning) { totalSize += allocated_peft_buffer_size1 + allocated_peft_buffer_size2; totalSize += peft_token_infos_size; + totalSize += 3 * gqa_ptr_array_size; } if (offload) { // assert that we have enough reserved work space left size_t totalSharedSize = infer_mode == TREE_VERIFY_MODE - ? totalSize - - (key_cache_size + value_cache_size + qkv_max_proj_size) * - size_of_dt - : totalSize - (key_cache_size + value_cache_size) * size_of_dt; + ? totalSize - (query_tmp_size + key_cache_size + + value_cache_size + qkv_max_proj_size) * + size_of_dt + : totalSize - + (query_tmp_size + key_cache_size + value_cache_size) * + size_of_dt; size_t instance_size = size_of_dt * (infer_mode == TREE_VERIFY_MODE - ? key_cache_size + value_cache_size + qkv_max_proj_size - : key_cache_size + value_cache_size); + ? query_tmp_size + key_cache_size + value_cache_size + + qkv_max_proj_size + : query_tmp_size + key_cache_size + value_cache_size); assert(gpu_mem_allocator.reserved_total_size - gpu_mem_allocator.reserved_allocated_size >= @@ -1843,6 +2166,26 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( valueCache = gpu_mem_allocator.allocate_instance_untyped(value_cache_size * size_of_dt); + // gqa pointers + if (num_q_heads > num_kv_heads) { + assert(num_q_heads % num_kv_heads == 0 && + "Num Q heads must be a multiple of num KV heads"); + d_A_array = (void **)gpu_mem_allocator.allocate_instance_untyped( + gqa_ptr_array_size); + d_B_array = (void **)gpu_mem_allocator.allocate_instance_untyped( + gqa_ptr_array_size); + d_C_array = (void **)gpu_mem_allocator.allocate_instance_untyped( + gqa_ptr_array_size); + if (enable_peft_finetuning) { + d_A_array2 = (void **)gpu_mem_allocator.allocate_instance_untyped( + gqa_ptr_array_size); + d_B_array2 = (void **)gpu_mem_allocator.allocate_instance_untyped( + gqa_ptr_array_size); + d_C_array2 = (void **)gpu_mem_allocator.allocate_instance_untyped( + gqa_ptr_array_size); + } + } + token_infos = static_cast( handler.batch_config_metadata->tokens_info); request_infos = static_cast( @@ -1897,6 +2240,68 @@ IncMultiHeadSelfAttentionMeta::~IncMultiHeadSelfAttentionMeta(void) { } } +template void Kernels::IncMultiHeadAttention::run_batched_matmul( + IncMultiHeadSelfAttentionMeta const *meta, + hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + void const *alpha, + half const *A, + hipblasDatatype_t Atype, + int lda, + long long int strideA, + half const *B, + hipblasDatatype_t Btype, + int ldb, + long long int strideB, + void const *beta, + half *C, + hipblasDatatype_t Ctype, + int ldc, + long long int strideC, + int batchCount, + hipblasDatatype_t computeType, + hipblasGemmAlgo_t algo, + hipStream_t stream, + int batch_ratio_a, + int batch_ratio_b, + int batch_ratio_c, + bool bwd); + +template void Kernels::IncMultiHeadAttention::run_batched_matmul( + IncMultiHeadSelfAttentionMeta const *meta, + hipblasHandle_t handle, + hipblasOperation_t transa, + hipblasOperation_t transb, + int m, + int n, + int k, + void const *alpha, + float const *A, + hipblasDatatype_t Atype, + int lda, + long long int strideA, + float const *B, + hipblasDatatype_t Btype, + int ldb, + long long int strideB, + void const *beta, + float *C, + hipblasDatatype_t Ctype, + int ldc, + long long int strideC, + int batchCount, + hipblasDatatype_t computeType, + hipblasGemmAlgo_t algo, + hipStream_t stream, + int batch_ratio_a, + int batch_ratio_b, + int 
batch_ratio_c, + bool bwd); + template void Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( IncMultiHeadSelfAttentionMeta const *m, diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 73227c870..874e8a02e 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1114,7 +1114,6 @@ void apply_scaling_and_rotary(IncMultiHeadSelfAttentionMeta const *m, int parallelism = m->kProjSize * num_tokens * m->num_q_heads; if (m->scaling_query) { - int parallelism = m->qProjSize * m->num_q_heads * num_tokens; scaling_query_kernel<<::Type; using K_vec = typename VEC_K::Type; @@ -90,7 +94,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( int const tidx = threadIdx.x; // head id int const head_idx = blockIdx.x; - // nth request idx + int const kv_head_idx = head_idx / (num_q_heads / num_kv_heads); + // request idx int const request_idx = blockIdx.y; // request id in batch config @@ -129,7 +134,7 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + first_token_idx * per_head_size * total_num_heads + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; @@ -146,8 +151,11 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( // The number of keys per warp. constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - DT const *k_cache_batch = - key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + DT const *k_cache_batch = key_cache + + batch_config_request_id * + (per_head_size * num_kv_heads) * + max_seq_length + + ki; int ti_end = div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -156,7 +164,7 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { q_vecs[ki_o][ii] = *reinterpret_cast( - q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + + q_ptr + (per_head_size * total_num_heads * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); } @@ -173,8 +181,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( if (ti < totalCacheSize) { k[ii] = *reinterpret_cast( - k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + - jj); + k_cache_batch + ti_circ * (per_head_size * num_kv_heads) + + kv_head_idx * per_head_size + jj); } } float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); @@ -270,16 +278,19 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( zero(out); // The base pointer for the value in the cache buffer. - DT const *v_cache_batch = - value_cache + batch_config_request_id * max_seq_length * hidden_size + - vi; + DT const *v_cache_batch = value_cache + + batch_config_request_id * max_seq_length * + (per_head_size * num_kv_heads) + + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { // Load the values from the cache. 
int const ti_circ = ti % max_seq_length; + V_vec v = *reinterpret_cast( - v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + v_cache_batch + ti_circ * (per_head_size * num_kv_heads) + + kv_head_idx * per_head_size); bool const mask = (ti >= bitmask.non_tree_cache_size && (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & @@ -320,10 +331,12 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - convert_from_float(*reinterpret_cast( - output_ptr + (first_token_idx + qi) * hidden_size + - head_idx * per_head_size + vi), - out); + convert_from_float( + *reinterpret_cast(output_ptr + + (first_token_idx + qi) * + (per_head_size * num_q_heads) + + head_idx * per_head_size + vi), + out); } } } @@ -338,42 +351,40 @@ __global__ void spec_inc_store_kv_cache( BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, BatchConfig::BitMask *causalMask, - int qProjSize, - int kProjSize, - int vProjSize, + bool is_root, int num_tokens, int max_seq_len, - bool is_root, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / (hidden_size); - int offset = i % hidden_size; - - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; - - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; + int head_dim, + int num_q_heads, + int num_kv_heads) { + CUDA_KERNEL_LOOP(i, num_tokens * head_dim * num_kv_heads) { + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + // kCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + // vCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + int token_idx = i / (head_dim * num_kv_heads); + int head_idx = (i / head_dim) % num_kv_heads; + int offset = i % head_dim; + + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + int key_src_idx = token_idx * head_dim * tot_num_heads + + head_dim * num_q_heads + head_dim * head_idx + offset; + int val_src_idx = key_src_idx + head_dim * num_kv_heads; int const req_id = tokenInfos[token_idx].request_index; // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; int const request_token_offset = requestInfo[req_id].first_token_offset_in_batch; - BatchConfig::BitMask bitmask = causalMask[req_id]; - - // if prompt token -> token id - // if tree token: - int const cache_idx = bitmask.prompt_size + bitmask.non_tree_cache_size + bitmask.tree_size - 1 - bitmask.this_layer_size + token_idx - request_token_offset; + int dst_idx = req_id * (head_dim * num_kv_heads * max_seq_len) + + cache_idx * head_dim * num_kv_heads + head_idx * head_dim + + offset; - kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + - offset] = vVal; + kCache_ptr[dst_idx] = devQKVProjArray[key_src_idx]; + vCache_ptr[dst_idx] = devQKVProjArray[val_src_idx]; } } @@ -382,9 +393,13 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, BeamSearchBatchConfig const *bc, hipStream_t stream) { int num_tokens = bc->num_active_tokens(); + int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + assert(m->hidden_size % m->num_q_heads == 0); + int head_dim = m->hidden_size / m->num_q_heads; + assert(head_dim == m->qProjSize); int curr_depth = bc->beamRequestsInfo[0].current_depth; if (num_tokens > 0) { - int parallelism = 
m->hidden_size * KV_WEIGHT_NUM * num_tokens; + int parallelism = head_dim * tot_num_heads * num_tokens; hipLaunchKernelGGL(HIP_KERNEL_NAME(spec_inc_store_kv_cache<DT>
), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), @@ -398,14 +413,13 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, m->beam_token_infos, m->beam_request_infos, m->causalMask, - m->qProjSize, - m->kProjSize, - m->vProjSize, + /*root*/ curr_depth == 0, num_tokens, BatchConfig::max_sequence_length() + BatchConfig::max_spec_tree_token_num(), - /*root*/ curr_depth == 0, - m->hidden_size); + head_dim, + m->num_q_heads, + m->num_kv_heads); } } @@ -431,7 +445,8 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, BatchConfig::max_sequence_length() + \ BatchConfig::max_spec_tree_token_num(), \ m->qProjSize, \ - m->hidden_size, \ + m->num_q_heads, \ + m->num_kv_heads, \ m->request_infos, \ m->beam_request_infos, \ m->causalMask, \ @@ -483,7 +498,7 @@ template void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, BeamSearchBatchConfig const *bc, int shard_id, - DT *output_ptr, + DT *attn_heads, hipStream_t stream) { checkCUDA(hipblasSetStream(m->handle.blas, stream)); checkCUDNN(miopenSetStream(m->handle.dnn, stream)); @@ -492,209 +507,259 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, assert(data_type_size(m->output_type[0]) == sizeof(DT)); hipblasDatatype_t compute_type = hipblas_data_type; - int num_tokens = bc->num_active_tokens(); - int tokens_previous_requests = 0; - int tokens_prev_requests_squares = 0; - int q_block_size = m->qProjSize; - - int kt_block_size = m->kProjSize; - int kt_req_block_size = kt_block_size * m->num_q_heads * - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num()); - int vt_block_size = m->vProjSize; - int vt_req_block_size = vt_block_size * m->num_q_heads * - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num()); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i] || (!bc->requestsInfo[i].prompt_phase) || - (bc->requestsInfo[i].num_tokens_in_batch == 0)) { - continue; - } else if (tokens_previous_requests < bc->num_generation_tokens) { - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); + int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + int max_seq_len = BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num(); + + int num_processed_prompt_tokens = 0; + for (int req_idx = 0; req_idx < bc->max_requests_per_batch(); req_idx++) { + if (bc->request_completed[req_idx] || + (!bc->requestsInfo[req_idx].prompt_phase) || + (bc->requestsInfo[req_idx].num_tokens_in_batch == 0)) { continue; } // all requests in prompt phase should only have one sub requests; - assert(bc->sub_requests[i] == 1); - // int num_new_tokens = bc->num_processing_tokens[i]; - // int total_tokens = bc->token_last_available_idx[i] + 1; + assert(bc->sub_requests[req_idx] == 1); - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; + int num_new_tokens = bc->requestsInfo[req_idx].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[req_idx].first_token_depth_in_request + + bc->requestsInfo[req_idx].num_tokens_in_batch; if (num_new_tokens <= 0) { continue; } - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; - int lda = k * m->num_q_heads * 
QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); + // Step 1: compute query-key product QK.T/sqrt(d_k) + { + // Scale by sqrt(d_k) as per the original attention paper + DT alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize)); + } + // after transpositions + int m_ = num_new_tokens; + int n = total_tokens; + int k = m->qProjSize; + // before transpositions + int lda = m->qProjSize * tot_num_heads; + int ldb = m->kProjSize * m->num_kv_heads; + int ldc = num_new_tokens; + // N.B. strides are applied before transpose operations + int strideA = m->qProjSize; + int strideB = m->kProjSize; + int strideC = num_new_tokens * total_tokens; + + // matrix A: devQKVProjArray + // matrix A's layout: [qProjSize, tot_num_heads, num_new_tokens] + // To get query projection, skip over Q entries from previous requests + DT const *A = static_cast
<DT *>(m->devQKVProjArray) + + bc->requestsInfo[req_idx].first_token_offset_in_batch * + m->qProjSize * (m->num_q_heads + 2 * m->num_kv_heads); + // matrix B: key cache + // matrix B's layout: [kProjSize, num_kv_heads, total_tokens] + // To get B, skip over K entries from previous requests (all heads + + // padding) + DT const *B = static_cast<DT *>
(m->keyCache) + + req_idx * (m->kProjSize * m->num_kv_heads * max_seq_len); + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
<DT *>(m->qk_prods); + Kernels::IncMultiHeadAttention::run_batched_matmul<DT>
( + m, + m->handle.blas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + hipblas_data_type, + lda, + strideA, + B, + hipblas_data_type, + ldb, + strideB, + &beta, + C, + hipblas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT, + stream, + 1, + m->num_q_heads / m->num_kv_heads, + 1); + } + + // Step 2: Add alibi position bias to qk production + { + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
<DT *>(m->qk_prods); + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + Kernels::IncMultiHeadAttention::apply_position_bias_qkprd<DT>
), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); + } } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast
<DT *>(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; - DT const *B = static_cast<DT *>
(m->keyCache) + i * kt_req_block_size; - DT *C = static_cast<DT *>
(m->qk_prods); - - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - strideA, - B, - hipblas_data_type, - ldb, - strideB, - &beta, - C, - hipblas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - - // add alibi position bias to qk production - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd
<DT>), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); + + // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods + // with -inf to force causal attention. + { + assert(num_new_tokens <= total_tokens); + if (num_new_tokens > 1) { + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast<DT *>
(m->qk_prods); + size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; + hipLaunchKernelGGL( + HIP_KERNEL_NAME(spec_fill_entries_above_diagonal
<DT>), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + static_cast<DT>
(-INFINITY)); + } } - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens); - if (num_new_tokens > 1) { - size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(spec_fill_entries_above_diagonal
<DT>), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - static_cast<DT>
(-INFINITY)); + // Step 4: Compute Softmax(QK.T/sqrt(d_k)) + { + // Compute Softmax(QK^T/sqrt(d_k)) + // Before modifying the parameters below, make sure to read the following + // description of the CUDNN_TENSOR_NCHW tensor layout, from + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: + // This tensor format specifies that the data is laid out in the following + // order: batch size, feature maps, rows, columns. The strides are + // implicitly defined in such a way that the data are contiguous in memory + // with no padding between images, feature maps, rows, and columns; the + // columns are the inner dimension and the images are the outermost + // dimension. + int n_param = m->num_q_heads; + int c_param = total_tokens; + int h_param = 1; + int w_param = num_new_tokens; + checkCUDNN(miopenSet4dTensorDescriptor( + m->qk_tensor, miopen_data_type, n_param, c_param, h_param, w_param)); + float softmax_alpha = 1.0f, softmax_beta = 0.0f; + // matrix C: qk_prods (current req only) + // matrix C's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C = static_cast
<DT *>(m->qk_prods); + // matrix C_softmax: qk_prods_softmax (current req only) + // matrix C_softmax's layout: [num_new_tokens, total_tokens, num_q_heads] + DT *C_softmax = static_cast<DT *>
(m->qk_prods_softmax); + // The softmax operation below is executed according to the + // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax, + MIOPEN_SOFTMAX_ACCURATE, + MIOPEN_SOFTMAX_MODE_CHANNEL)); } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(miopenSet4dTensorDescriptor( - m->qk_tensor, miopen_data_type, n_param, c_param, h_param, w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax) + - m->num_q_heads * tokens_prev_requests_squares; - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax, - MIOPEN_SOFTMAX_ACCURATE, - MIOPEN_SOFTMAX_MODE_CHANNEL)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = m->vProjSize; - n = num_new_tokens; - k = total_tokens; - lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - strideA = vt_block_size; - strideB = num_new_tokens * total_tokens; - strideC = m->vProjSize; - // To get A, skip over V^T entries from previous requests (all heads + - // padding) - A = static_cast
<DT *>(m->valueCache) + i * vt_req_block_size; - // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - B = C_softmax; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - - int token_offset = bc->requestsInfo[i].first_token_offset_in_batch; - - C = static_cast<DT *>
(m->attn_heads) + - (token_offset)*m->num_q_heads * m->vProjSize; - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - strideA, - B, - hipblas_data_type, - ldb, - strideB, - &beta, - C, - hipblas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - - tokens_previous_requests += num_new_tokens; - tokens_prev_requests_squares += num_new_tokens * total_tokens; + + // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ + // softmax(QK.T/sqrt(d_k)).T + { + DT alpha = 1.0f, beta = 0.0f; + // after transpositions + int m_ = m->vProjSize; + int n = num_new_tokens; + int k = total_tokens; + // before transpositions + int lda = m_ * m->num_kv_heads; + int ldb = n; + int ldc = m_ * m->num_q_heads; + // N.B. strides are applied before transpose operations + int strideA = m->vProjSize; + int strideB = num_new_tokens * total_tokens; + int strideC = m->vProjSize; + // matrix A: value cache + // matrix A's layout: [vProjSize, num_kv_heads, total_tokens] + // To get A, skip over V.T entries from previous requests (all heads + + // padding) + DT *A = static_cast
<DT *>(m->valueCache) + + req_idx * (m->vProjSize * m->num_kv_heads * max_seq_len); + // matrix B: qk_prods_softmax (current req only) + // matrix B's layout: [num_new_tokens, total_tokens, num_q_heads] + // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous + // requests (all heads) + DT *B = static_cast<DT *>
(m->qk_prods_softmax); + // matrix C: attn heads + // matrix C's layout: [vProjSize, num_q_heads, num_new_tokens] + // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous + // requests + // store the result attn heads, also skip the generation tokens + DT *C = static_cast<DT *>
(attn_heads) + + (bc->requestsInfo[req_idx].first_token_offset_in_batch) * + m->num_q_heads * m->vProjSize; + Kernels::IncMultiHeadAttention::run_batched_matmul<DT>
( + m, + m->handle.blas, + HIPBLAS_OP_N, + HIPBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + hipblas_data_type, + lda, + strideA, + B, + hipblas_data_type, + ldb, + strideB, + &beta, + C, + hipblas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT, + stream, + m->num_q_heads / m->num_kv_heads, + 1, + 1); + } + + num_processed_prompt_tokens += num_new_tokens; } - if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { + if (num_processed_prompt_tokens != + (bc->num_active_tokens() - bc->num_generation_tokens)) { bc->print(); - printf("tokens_previous_requests: %i\n", tokens_previous_requests); - printf("num_tokens: %i\n", num_tokens); + printf("num_processed_prompt_tokens: %i\n", num_processed_prompt_tokens); + printf("bc->num_active_tokens(): %i\n", bc->num_active_tokens()); printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); } - assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); + assert(num_processed_prompt_tokens == + (bc->num_active_tokens() - bc->num_generation_tokens)); } template @@ -705,41 +770,30 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, DT *output_ptr, hipStream_t stream) { - // phase 0: copy calculated qkv into devQKVProjArray - // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); - - checkCUDA(hipMemcpyAsync( - m->devQKVProjArray, - qkv_ptr, - qkv_proj_size * sizeof(DT), // is this right, do we need layers etc here - hipMemcpyDeviceToDevice, - stream)); - // phase 1: Implement kernel to compute KQV for input tokens - // TODO WARNING: this is commented out only because we are fixing the inc_attn - // first - apply_scaling_and_rotary( + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + assert(m->qProjSize == m->kProjSize && m->qProjSize == m->vProjSize); + size_t tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + size_t qkv_proj_size = m->qProjSize * tot_num_heads * bc->num_active_tokens(); + + checkCUDA(hipMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * sizeof(DT), + hipMemcpyDeviceToDevice, + stream)); + + // phase 1: Apply scaling and rotary embedding + Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( m, bc, shard_id, static_cast
<DT *>(m->devQKVProjArray), stream); + + // phase 2: Update key/val cache + update_kv_cache_kernel<DT>
(m, bc, stream); + if (bc->num_generation_tokens > 0) { - compute_spec_inc_attention_kernel_generation
<DT>( - m, bc, static_cast<DT *>
(m->attn_heads), stream); + compute_spec_inc_attention_kernel_generation<DT>
(m, bc, output_ptr, stream); } - // phase 3: Compute attention score - // 3 kernels for pahse 3: matmul1 - softmax - matmal2 if (bc->num_tokens > bc->num_generation_tokens) { compute_attention_kernel_prompt(m, bc, shard_id, output_ptr, stream); } - - int num_tokens = bc->num_active_tokens(); - - checkCUDA(hipMemcpyAsync(output_ptr, - m->attn_heads, - m->oProjSize * num_tokens * sizeof(DT), - hipMemcpyDeviceToDevice, - stream)); } } // namespace SpecIncMultiHeadSelfAttention diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index 2bfa88bdc..d6d258de1 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -29,8 +29,6 @@ using Legion::Memory; #define WARP_SIZE 32 -using namespace Kernels::IncMultiHeadAttention; - namespace Kernels { namespace TreeIncMultiHeadAttention { @@ -69,14 +67,16 @@ __global__ void compute_attention_kernel_fused_kernel( int const max_seq_length, int const max_token_per_batch, int per_head_size, - int hidden_size, + int num_q_heads, + int num_kv_heads, BatchConfig::PerRequestInfo *request_infos, - int num_heads, int num_requests, BatchConfig::BitMask *causalMask, bool *request_completed, int qk_smem_sz) { + int total_num_heads = num_q_heads + 2 * num_kv_heads; + // q, k using Q_vec = typename VEC_K::Type; using K_vec = typename VEC_K::Type; @@ -94,6 +94,7 @@ __global__ void compute_attention_kernel_fused_kernel( int const tidx = threadIdx.x; // head id int const head_idx = blockIdx.x; + int const kv_head_idx = head_idx / (num_q_heads / num_kv_heads); // request idx int const request_idx = blockIdx.y; @@ -131,7 +132,7 @@ __global__ void compute_attention_kernel_fused_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + first_token_idx * per_head_size * total_num_heads + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; @@ -148,8 +149,11 @@ __global__ void compute_attention_kernel_fused_kernel( // The number of keys per warp. constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - DT const *k_cache_batch = - key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + DT const *k_cache_batch = key_cache + + batch_config_request_id * + (per_head_size * num_kv_heads) * + max_seq_length + + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -158,7 +162,7 @@ __global__ void compute_attention_kernel_fused_kernel( #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { q_vecs[ki_o][ii] = *reinterpret_cast( - q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + + q_ptr + (per_head_size * total_num_heads * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); // if (head_idx == 0 && request_idx == 1 && tidx == 0) { @@ -177,8 +181,8 @@ __global__ void compute_attention_kernel_fused_kernel( int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; if (ti < tlength) { k[ii] = *reinterpret_cast( - k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + - jj); + k_cache_batch + ti_circ * (per_head_size * num_kv_heads) + + kv_head_idx * per_head_size + jj); } } float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); @@ -279,9 +283,10 @@ __global__ void compute_attention_kernel_fused_kernel( zero(out); // The base pointer for the value in the cache buffer. 
- DT const *v_cache_batch = - value_cache + batch_config_request_id * max_seq_length * hidden_size + - vi; + DT const *v_cache_batch = value_cache + + batch_config_request_id * max_seq_length * + (per_head_size * num_kv_heads) + + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { @@ -289,7 +294,8 @@ __global__ void compute_attention_kernel_fused_kernel( int const ti_circ = ti % max_seq_length; // int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; V_vec v = *reinterpret_cast( - v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + v_cache_batch + ti_circ * (per_head_size * num_kv_heads) + + kv_head_idx * per_head_size); if (ti < tlength) { bool const mask = @@ -335,10 +341,12 @@ __global__ void compute_attention_kernel_fused_kernel( // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - convert_from_float(*reinterpret_cast( - output_ptr + (first_token_idx + qi) * hidden_size + - head_idx * per_head_size + vi), - out); + convert_from_float( + *reinterpret_cast(output_ptr + + (first_token_idx + qi) * + (per_head_size * num_q_heads) + + head_idx * per_head_size + vi), + out); // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { // printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", // out.x, @@ -360,34 +368,36 @@ __global__ void commit_tokens_kernel( DT *kCache_ptr, DT *vCache_ptr, TreeVerifyBatchConfig::CommittedTokensInfo const *committedTokenInfos, - int qProjSize, - int kProjSize, - int vProjSize, + int head_dim, + int num_q_heads, + int num_kv_heads, int num_tokens_to_commit, int num_active_tokens_in_last_batch, - int max_seq_len, - int hidden_size) { - - CUDA_KERNEL_LOOP(i, num_tokens_to_commit * hidden_size) { - - int token_pos = i / (hidden_size); - int token_idx_in_last_batch = committedTokenInfos[token_pos].token_index; - int offset = i % hidden_size; + int max_seq_len) { + + CUDA_KERNEL_LOOP(i, num_tokens_to_commit * head_dim * num_kv_heads) { + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + // kCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + // vCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + int token_idx = i / (head_dim * num_kv_heads); + int token_idx_in_last_batch = committedTokenInfos[token_idx].token_index; + int head_idx = (i / head_dim) % num_kv_heads; + int offset = i % head_dim; assert(token_idx_in_last_batch < num_active_tokens_in_last_batch); - size_t val_idx = token_idx_in_last_batch * QKV_WEIGHT_NUM * hidden_size + - hidden_size + offset; - - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + int key_src_idx = token_idx_in_last_batch * head_dim * tot_num_heads + + head_dim * num_q_heads + head_dim * head_idx + offset; + int val_src_idx = key_src_idx + head_dim * num_kv_heads; - int const req_id = committedTokenInfos[token_pos].request_index; - int const tok_id = committedTokenInfos[token_pos].token_depth; + int const req_id = committedTokenInfos[token_idx].request_index; + int const tok_id = committedTokenInfos[token_idx].token_depth; + int dst_idx = req_id * (head_dim * num_kv_heads * max_seq_len) + + tok_id * head_dim * num_kv_heads + head_idx * head_dim + + offset; - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; + kCache_ptr[dst_idx] = 
devQKVProjArray[key_src_idx]; + vCache_ptr[dst_idx] = devQKVProjArray[val_src_idx]; } } @@ -395,9 +405,12 @@ template void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, TreeVerifyBatchConfig const *bc, hipStream_t stream) { + int head_dim = m->hidden_size / m->num_q_heads; + assert(head_dim == m->qProjSize); + // int tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; int num_tokens_to_commit = bc->num_tokens_to_commit; if (num_tokens_to_commit > 0) { - int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens_to_commit; + int parallelism = head_dim * m->num_kv_heads * num_tokens_to_commit; hipLaunchKernelGGL( HIP_KERNEL_NAME(commit_tokens_kernel
<DT>), GET_BLOCKS(parallelism), @@ -409,48 +422,12 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, static_cast<DT *>
(m->valueCache), m->committed_token_infos, m->qProjSize, - m->kProjSize, - m->vProjSize, + m->num_q_heads, + m->num_kv_heads, num_tokens_to_commit, m->num_active_infr_tokens, // number of active tokens in previous batch BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num(), - m->hidden_size); - } -} - -template -__global__ void update_tree_branch_kv_cache( - DT const *devQKVProjArray, - DT *kCache_ptr, - DT *vCache_ptr, - TreeVerifyBatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int vProjSize, - int num_tokens_in_branch, - int processed_tokens_in_batch, - int total_tokens_in_batch, - int max_seq_len, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens_in_branch * hidden_size) { - - int token_idx = i / (hidden_size); - int offset = i % hidden_size; - - token_idx += processed_tokens_in_batch; // get index in the whole batch - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; - - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; - - int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; + BatchConfig::max_spec_tree_token_num()); } } @@ -461,21 +438,25 @@ __global__ void update_tree_branch_kv_cache_fused( DT *vCache_ptr, TreeVerifyBatchConfig::PerTokenInfo const *tokenInfos, BatchConfig::PerRequestInfo *request_infos, - int qProjSize, - int kProjSize, - int vProjSize, + int head_dim, + int num_q_heads, + int num_kv_heads, int num_new_tokens, - int max_seq_len, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_new_tokens * hidden_size) { + int max_seq_len) { - int token_idx = i / hidden_size; - int offset = i % hidden_size; - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; + CUDA_KERNEL_LOOP(i, num_new_tokens * head_dim * num_kv_heads) { + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + // kCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + // vCache_ptr: [head_dim, num_kv_heads, max_seq_len, max_batch_size] + int token_idx = i / (head_dim * num_kv_heads); + int head_idx = (i / head_dim) % num_kv_heads; + int offset = i % head_dim; - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; + int tot_num_heads = num_q_heads + 2 * num_kv_heads; + + int key_src_idx = token_idx * head_dim * tot_num_heads + + head_dim * num_q_heads + head_dim * head_idx + offset; + int val_src_idx = key_src_idx + head_dim * num_kv_heads; int const req_id = tokenInfos[token_idx].request_index; // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; @@ -485,19 +466,13 @@ __global__ void update_tree_branch_kv_cache_fused( int const first_token_depth = request_infos[req_id].first_token_depth_in_request; - // if(i % hidden_size == 0){ - // printf("update token request id: %d, %d, %d real id %d, value%.10f\n", - // req_id, token_idx, request_token_offset,(token_idx + first_token_depth - // - request_token_offset), kVal); - // } - kCache_ptr[req_id * (hidden_size * max_seq_len) + - (token_idx + first_token_depth - request_token_offset) * - hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + - (token_idx + first_token_depth - request_token_offset) * - hidden_size + - offset] = vVal; + int dst_idx = req_id * (head_dim 
* num_kv_heads * max_seq_len) + + (token_idx + first_token_depth - request_token_offset) * + head_dim * num_kv_heads + + head_idx * head_dim + offset; + + kCache_ptr[dst_idx] = devQKVProjArray[key_src_idx]; + vCache_ptr[dst_idx] = devQKVProjArray[val_src_idx]; } } @@ -543,9 +518,9 @@ __global__ void tree_fill_entries_above_diagonal(DT *matrix, BatchConfig::BatchConfig::max_spec_tree_token_num(), \ BatchConfig::max_tokens_per_batch(), \ m->qProjSize, \ - m->hidden_size, \ - m->request_infos, \ m->num_q_heads, \ + m->num_kv_heads, \ + m->request_infos, \ bc->num_active_requests(), \ m->causalMask, \ m->request_completed, \ @@ -559,24 +534,26 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, // update the kv cache // update K-V cache + int head_dim = m->hidden_size / m->num_q_heads; + assert(head_dim == m->qProjSize); int num_new_tokens = bc->num_active_tokens(); - int parallelism = m->hidden_size * num_new_tokens; - update_tree_branch_kv_cache_fused<<>>( - static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - m->request_infos, - m->qProjSize, - m->kProjSize, - m->vProjSize, - num_new_tokens, - BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num(), - m->hidden_size); + int parallelism = head_dim * m->num_kv_heads * num_new_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(update_tree_branch_kv_cache_fused), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + static_cast
<DT *>(m->devQKVProjArray), + static_cast<DT *>
(m->keyCache), + static_cast<DT *>
(m->valueCache), + m->token_infos, + m->request_infos, + m->qProjSize, + m->num_q_heads, + m->num_kv_heads, + num_new_tokens, + BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); dim3 grid(m->num_q_heads, bc->num_active_requests()); int const per_head_size = m->qProjSize; @@ -617,36 +594,23 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // tokens for the current batch m->num_active_infr_tokens = bc->num_active_tokens(); - // phase 0: copy calculated qkv into devQKVProjArray - // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); - - checkCUDA(hipMemcpyAsync( - m->devQKVProjArray, - qkv_ptr, - qkv_proj_size * sizeof(DT), // is this right, do we need layers etc here - hipMemcpyDeviceToDevice, - stream)); - - // phase 1: Implement kernel to compute KQV for input tokens - // TODO WARNING: this is commented out only because we are fixing the inc_attn - // first - apply_scaling_and_rotary( - m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); + // devQKVProjArray: [head_dim, tot_num_heads, num_tokens] + assert(m->qProjSize == m->kProjSize && m->qProjSize == m->vProjSize); + size_t tot_num_heads = m->num_q_heads + 2 * m->num_kv_heads; + size_t qkv_proj_size = m->qProjSize * tot_num_heads * bc->num_active_tokens(); - // phase 2: No need to update key/val cache - compute_attention_kernel_fused
<DT>( - m, bc, static_cast<DT *>
(m->attn_heads), stream); + checkCUDA(hipMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * sizeof(DT), + hipMemcpyDeviceToDevice, + stream)); - int processed_tokens_in_batch = bc->num_active_tokens(); + // phase 1: Apply scaling and rotary embedding + Kernels::IncMultiHeadAttention::apply_scaling_and_rotary( + m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); - int num_tokens = bc->num_active_tokens(); - hipMemcpyAsync(output_ptr, - m->attn_heads, - m->oProjSize * num_tokens * sizeof(DT), - hipMemcpyDeviceToDevice, - stream); + // phase 2: No need to update key/val cache + compute_attention_kernel_fused
<DT>(m, bc, output_ptr, stream); } } // namespace TreeIncMultiHeadAttention
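The kernels touched above (spec_inc_store_kv_cache, commit_tokens_kernel, update_tree_branch_kv_cache_fused) all share the same index arithmetic: devQKVProjArray is read as [head_dim, num_q_heads + 2 * num_kv_heads, num_tokens] and the key/value caches are written as [head_dim, num_kv_heads, max_seq_len, max_batch_size]. A minimal host-side sketch of that arithmetic follows; the helper names and concrete sizes are illustrative only and are not part of the patch.

#include <cassert>
#include <cstdio>

struct KvCacheShape {
  int head_dim;      // qProjSize == kProjSize == vProjSize
  int num_q_heads;
  int num_kv_heads;
  int max_seq_len;
};

// Offset of element (head_idx, offset) of token token_idx's key projection in
// devQKVProjArray, laid out as [head_dim, num_q_heads + 2*num_kv_heads, num_tokens].
int key_src_index(KvCacheShape s, int token_idx, int head_idx, int offset) {
  int tot_num_heads = s.num_q_heads + 2 * s.num_kv_heads;
  return token_idx * s.head_dim * tot_num_heads  // skip earlier tokens
         + s.head_dim * s.num_q_heads            // skip this token's Q heads
         + s.head_dim * head_idx + offset;       // K head and element within it
}

// The matching value element sits one KV block (head_dim * num_kv_heads) further.
int val_src_index(KvCacheShape s, int token_idx, int head_idx, int offset) {
  return key_src_index(s, token_idx, head_idx, offset) + s.head_dim * s.num_kv_heads;
}

// Destination offset in kCache/vCache, laid out as
// [head_dim, num_kv_heads, max_seq_len, max_batch_size].
int cache_dst_index(KvCacheShape s, int req_id, int tok_id, int head_idx, int offset) {
  return req_id * (s.head_dim * s.num_kv_heads * s.max_seq_len)
         + tok_id * s.head_dim * s.num_kv_heads
         + head_idx * s.head_dim + offset;
}

int main() {
  KvCacheShape s{128, 32, 8, 512};  // made-up model shape
  // Token 3 of the batch is stored at depth 10 of request 1.
  int src = key_src_index(s, /*token_idx=*/3, /*head_idx=*/2, /*offset=*/5);
  int dst = cache_dst_index(s, /*req_id=*/1, /*tok_id=*/10, /*head_idx=*/2, /*offset=*/5);
  assert(val_src_index(s, 3, 2, 5) == src + s.head_dim * s.num_kv_heads);
  std::printf("key src %d -> cache dst %d\n", src, dst);
  return 0;
}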
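When batch_ratio_a/b/c are not all 1, run_batched_matmul switches from hipblasGemmStridedBatchedEx to hipblasGemmBatchedEx with explicit per-head pointer arrays, so that each group of num_q_heads / num_kv_heads query heads reuses the same key or value block (the ratio is applied to B in the QK.T step and to A in the V step above). Below is a stand-alone sketch of that pointer-array construction; the head counts and strides are assumed values for illustration, not taken from the patch.

#include <cstdio>

int main() {
  // Example sizes; in the kernels they come from the attention meta.
  int const num_q_heads = 8;
  int const num_kv_heads = 2;
  int const batch_ratio = num_q_heads / num_kv_heads;  // query heads per KV head

  long long const strideA = 1 << 10;  // per-head stride of A, in elements
  long long const strideB = 1 << 12;  // per-head stride of B, in elements
  long long const strideC = 1 << 8;   // per-head stride of C, in elements

  // Mirrors h_X_array[batch] = X + (batch / batch_ratio_x) * strideX for the
  // QK.T step, where batch_ratio_a = 1, batch_ratio_b = batch_ratio, and
  // batch_ratio_c = 1: consecutive query heads advance A and C but share the
  // same B (key cache) block.
  for (int head = 0; head < num_q_heads; head++) {
    long long a_off = (long long)head * strideA;
    long long b_off = (long long)(head / batch_ratio) * strideB;
    long long c_off = (long long)head * strideC;
    std::printf("q head %d: A+%lld B+%lld C+%lld (kv head %d)\n",
                head, a_off, b_off, c_off, head / batch_ratio);
  }
  return 0;
}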