Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexllm #83

Draft
wants to merge 78 commits into
base: inference
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
e3b7e79
first commit
goliaro Nov 22, 2024
bd93585
rename
goliaro Nov 23, 2024
00d1a1b
helpers
goliaro Nov 23, 2024
82284f2
backup
goliaro Nov 23, 2024
6266ec7
backup
goliaro Nov 24, 2024
4fcdf79
finished impl
goliaro Nov 25, 2024
6d7aa57
add benchmarking scripts
goliaro Nov 25, 2024
09aa5af
fix
goliaro Nov 25, 2024
9c0d827
fix
goliaro Nov 25, 2024
784c8d9
build fix
goliaro Nov 25, 2024
ee48def
fix bugs
goliaro Nov 26, 2024
450c98f
update
goliaro Nov 27, 2024
98e025c
fix
goliaro Nov 27, 2024
d963933
fixes
goliaro Nov 27, 2024
287dadb
fix
goliaro Nov 27, 2024
6a1544d
update
goliaro Nov 28, 2024
10e1596
use one bc only
goliaro Nov 29, 2024
89a6287
fix
goliaro Nov 29, 2024
caa4880
fix
goliaro Nov 29, 2024
f27a224
fix
goliaro Nov 30, 2024
68dfa8c
fix
goliaro Nov 30, 2024
529d9f1
fix
goliaro Nov 30, 2024
6cb9266
fix
goliaro Nov 30, 2024
03498c1
fix
goliaro Nov 30, 2024
9fcf6a3
update
goliaro Nov 30, 2024
7fba224
Merge branch 'inference' into flexllm
goliaro Dec 1, 2024
07e8e62
add file to get dolly
goliaro Dec 1, 2024
bbde408
update
goliaro Dec 1, 2024
de2fcd9
update
goliaro Dec 1, 2024
c8dc468
add warmup requests
goliaro Dec 2, 2024
0b150e3
update
goliaro Dec 3, 2024
4ec561e
update
goliaro Dec 3, 2024
07b9e50
fix
goliaro Dec 4, 2024
b077d27
update
goliaro Dec 4, 2024
3939025
add scripts
goliaro Dec 4, 2024
943c06b
fix
goliaro Dec 4, 2024
9516e59
gqa fwd
goliaro Dec 5, 2024
bf1ff34
bwd gqa
goliaro Dec 5, 2024
b3a393a
fix
goliaro Dec 6, 2024
04c6857
fix
goliaro Dec 6, 2024
ec500d2
add overhead test
goliaro Dec 6, 2024
16b3a99
update script
goliaro Dec 7, 2024
c65cf0d
update script
goliaro Dec 7, 2024
bc30870
fixes
goliaro Dec 7, 2024
2cc42d1
fix
goliaro Dec 8, 2024
b094c12
add plot script
goliaro Dec 8, 2024
2d7910b
fixes, add data
goliaro Dec 8, 2024
ab0b209
update
goliaro Dec 8, 2024
066d9e8
add results and plots
goliaro Dec 9, 2024
5cc5535
update
goliaro Dec 10, 2024
770b5a9
add new files
goliaro Dec 10, 2024
8cddc22
update
goliaro Dec 10, 2024
97678e8
update
goliaro Jan 10, 2025
c38a70f
update
goliaro Jan 10, 2025
a33837b
update
goliaro Jan 10, 2025
3c7bbc4
add flashinfer
goliaro Jan 10, 2025
f3b7cab
update
goliaro Jan 10, 2025
302d470
add flashinfer w recompute
goliaro Jan 14, 2025
3ab9b4e
fix
goliaro Jan 15, 2025
8f52547
Merge branch 'inference' into flexllm
goliaro Feb 5, 2025
87b5403
remove exp data
goliaro Feb 7, 2025
85227c0
fix gqa, comment out flashinfer
goliaro Feb 7, 2025
e4d8c2c
Merge branch 'inference' into flexllm
goliaro Feb 9, 2025
6f26214
fix
goliaro Feb 9, 2025
32b526a
fix
goliaro Feb 10, 2025
2a7b1d0
fix
goliaro Feb 10, 2025
2176c6b
fix
goliaro Feb 10, 2025
b2d8c8b
fix test
goliaro Feb 10, 2025
8de1a14
format
goliaro Feb 10, 2025
64d6275
cleanup
goliaro Feb 10, 2025
a0117e0
clean
goliaro Feb 10, 2025
f8d3886
update
goliaro Feb 17, 2025
8f2d8dd
Merge branch 'inference' into flexllm
goliaro Feb 17, 2025
6d38983
update
goliaro Feb 17, 2025
6256118
fix
goliaro Feb 18, 2025
40a0a53
update
goliaro Feb 21, 2025
a575b0a
Merge branch 'inference' into flexllm
goliaro Feb 23, 2025
20488db
cleanup
goliaro Feb 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@
[submodule "deps/tokenizers-cpp"]
path = deps/tokenizers-cpp
url = https://github.com/mlc-ai/tokenizers-cpp.git
fetchRecurseSubmodules = true
fetchRecurseSubmodules = true
[submodule "deps/flashinfer"]
path = deps/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ include(variant)
# optional
include(optional)

# flashinfer
list(APPEND FLEXFLOW_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/deps/flashinfer/include)

if (FF_GPU_BACKEND STREQUAL "cuda")
list(APPEND FF_CC_FLAGS
-DFF_USE_CUDA)
Expand Down
1 change: 1 addition & 0 deletions deps/flashinfer
Submodule flashinfer added at be6bf5
2 changes: 1 addition & 1 deletion deps/legion
203 changes: 203 additions & 0 deletions include/flexflow/attention_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef _FLEXFLOW_ATTENTION_CONFIG_H_
#define _FLEXFLOW_ATTENTION_CONFIG_H_
#include "flexflow/batch_config.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

namespace FlexFlow {

// Size (in tokens) of one page of the paged KV cache.
constexpr uint32_t kPagesize = 64;

// Returns the number of KV-cache pages needed to hold `num_elements`
// tokens, i.e. ceil(num_elements / kPagesize).
//
// The arithmetic is kept in the signed-int domain: the original mixed
// `int` with the unsigned `kPagesize`, which would silently convert a
// negative `num_elements` into a huge unsigned value instead of a
// small (garbage, but debuggable) signed result.
constexpr int round_up_pages(int const num_elements) {
  constexpr int page_size = static_cast<int>(kPagesize);
  return (num_elements + page_size - 1) / page_size;
}

// Dispatches a code fragment specialized on the attention head dimension:
// binds a `constexpr size_t HEAD_DIM` equal to the runtime `head_dim`
// (one of 64, 128, 256) and executes __VA_ARGS__ with it in scope, so the
// fragment can instantiate templates on HEAD_DIM. Any other head_dim
// throws std::invalid_argument. (No comments inside the macro body: the
// trailing line continuations would splice them into following lines.)
#define DISPATCH_HEADDIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 64: { \
constexpr size_t HEAD_DIM = 64; \
__VA_ARGS__ \
break; \
} \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 256: { \
constexpr size_t HEAD_DIM = 256; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported head_dim: " << head_dim; \
throw std::invalid_argument(err_msg.str()); \
} \
}

// Holds the device-side metadata (index arrays + scratch workspaces) that
// the FlashInfer attention kernels consume. All metadata lives inside one
// externally-allocated buffer; this class only computes the layout and
// records pointers into it via assign_address().
class AttentionMetaData {
public:
  // Rule of zero: all members carry in-class initializers, so the
  // compiler-generated default/copy/move members are correct (the previous
  // hand-written memberwise copy constructor was redundant and a
  // maintenance hazard when members are added).
  AttentionMetaData() = default;

  // Number of int32 slots reserved for the index arrays (q/kv indptrs, kv
  // page indices, last-page lengths), padded up to at least 1M entries.
  // NOTE: uses static_cast<size_t> for the floor instead of `1ul`; both
  // std::max arguments must have the same type, and size_t is not
  // `unsigned long` on every platform (e.g. 64-bit Windows).
  static size_t compute_indices_size() {
    size_t const batch_size = BatchConfig::max_requests_per_batch();
    size_t const max_num_pages =
        round_up_pages(BatchConfig::max_sequence_length());
    return std::max((batch_size + 1) * 4 + max_num_pages * batch_size,
                    static_cast<size_t>(1024 * 1024));
  }

  // Total bytes required for all metadata (index arrays + custom mask +
  // workspaces), rounded up to a 16-byte multiple. Computed once and
  // cached in mem_size_; also initializes the workspace size members as a
  // side effect, which assign_address() relies on.
  size_t mem_size() {
    if (mem_size_ > 0) {
      return mem_size_;
    }
    size_t const indices_size = compute_indices_size();
    size_t const custom_mask_size = 0; // custom masks currently unused

    float_workspace_size = 128 * 1024 * 1024; // 128 MB
    int_workspace_size = 8 * 1024 * 1024;     // 8 MB
    workspace_size =
        float_workspace_size + int_workspace_size; // float + int workspace

    mem_size_ = alignTo(sizeof(int32_t) * indices_size +
                            sizeof(uint8_t) * custom_mask_size +
                            workspace_size,
                        16);
    return mem_size_;
  }

  // Carves the buffer at `ptr` (of `size` bytes) into the index arrays,
  // the custom-mask region, and the float/int workspaces. Passing nullptr
  // resets every sub-buffer pointer. `size` must be >= mem_size().
  void assign_address(void *ptr, int size) {
    if (ptr == nullptr) {
      q_indptr = nullptr;
      kv_indptr = nullptr;
      kv_indices = nullptr;
      kv_last_page_len = nullptr;
      qk_indptr = nullptr;
      custom_mask = nullptr;
      workspace = nullptr;
      float_workspace = nullptr;
      int_workspace = nullptr;
      return;
    }
    // Call mem_size() unconditionally: it initializes
    // float/int_workspace_size, which the layout below depends on. The
    // previous code only called it inside assert(), so under NDEBUG
    // int_workspace was computed with float_workspace_size == 0.
    size_t const required = mem_size();
    // `size` is signed; guard the sign explicitly so a negative value is
    // not converted to a huge unsigned number in the comparison.
    assert(size >= 0 && static_cast<size_t>(size) >= required &&
           "Insufficient memory size for attention metadata");
    (void)required; // silence unused-variable warning when NDEBUG is set
    size_t const batch_size = BatchConfig::max_requests_per_batch();
    size_t const max_num_pages =
        round_up_pages(BatchConfig::max_sequence_length());
    size_t const indices_size = compute_indices_size();
    size_t const custom_mask_size = 0; // must match mem_size()

    // Layout: [int32 index arrays][custom mask][float ws][int ws]
    q_indptr = static_cast<int32_t *>(ptr);
    kv_indptr = q_indptr + batch_size + 1;
    kv_indices = kv_indptr + batch_size + 1;
    kv_last_page_len = kv_indices + max_num_pages * batch_size;
    qk_indptr = kv_last_page_len + batch_size + 1;
    custom_mask = static_cast<uint8_t *>(ptr) + sizeof(int32_t) * indices_size;
    workspace = static_cast<void *>(static_cast<uint8_t *>(ptr) +
                                    sizeof(int32_t) * indices_size +
                                    sizeof(uint8_t) * custom_mask_size);
    float_workspace = workspace;
    int_workspace = static_cast<void *>(static_cast<uint8_t *>(workspace) +
                                        float_workspace_size);
  }

  void set_num_q_heads(uint32_t const num_q_heads) {
    num_q_heads_ = num_q_heads;
  }
  void set_num_kv_heads(uint32_t const num_kv_heads) {
    num_kv_heads_ = num_kv_heads;
  }
  void set_head_dim(uint32_t const head_dim) {
    head_dim_ = head_dim;
  }
  uint32_t num_q_heads() const {
    return num_q_heads_;
  }
  uint32_t num_kv_heads() const {
    return num_kv_heads_;
  }
  uint32_t head_dim() const {
    return head_dim_;
  }

  void set_enabled(bool const enabled) {
    enabled_ = enabled;
  }
  bool enabled() const {
    return enabled_;
  }

  uint32_t num_q_heads_ = 0;
  uint32_t num_kv_heads_ = 0;
  uint32_t head_dim_ = 0;

  // Sub-buffer pointers into the buffer handed to assign_address();
  // non-owning, all nullptr until assign_address() is called.
  int32_t *q_indptr = nullptr;
  int32_t *kv_indptr = nullptr;
  int32_t *kv_indices = nullptr;
  int32_t *kv_last_page_len = nullptr;
  int32_t *qk_indptr = nullptr;
  uint8_t *custom_mask = nullptr;
  void *workspace = nullptr;
  size_t workspace_size = 0;
  void *float_workspace = nullptr;
  size_t float_workspace_size = 0;
  void *int_workspace = nullptr;
  size_t int_workspace_size = 0;

  // Cached result of mem_size(); 0 means "not yet computed".
  size_t mem_size_ = 0;

  bool enabled_ = false;
  // batch size -> opaque FlashInfer handler (stored as void* to keep
  // FlashInfer types out of this header)
  std::unordered_map<int, void *> decode_handler_collections;
  std::unordered_map<int, void *> prompt_handler_collections;
};
} // namespace FlexFlow

#endif // _FLEXFLOW_ATTENTION_CONFIG_H_
6 changes: 4 additions & 2 deletions include/flexflow/batch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@

namespace FlexFlow {

// Rounds x up to the nearest multiple of y, e.g. alignTo(17, 16) == 32.
// Precondition: y > 0 (y == 0 divides by zero); x + y must not overflow
// int. constexpr (implies inline) so it can size compile-time buffers.
constexpr int alignTo(int x, int y) {
  return ((x + y - 1) / y) * y;
}

class InferenceResult;
class BeamInferenceResult;

Expand Down Expand Up @@ -84,8 +88,6 @@ class BatchConfig {
static int const MAX_SPEC_TREE_TOKEN_NUM = 64;
static int const MAX_PEFT_CONFIG_SIZE = 1024;

// Set by update

int num_tokens = 0, num_generation_tokens = 0;

struct PerRequestInfo {
Expand Down
16 changes: 16 additions & 0 deletions include/flexflow/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@
#define _FLEXFLOW_CONFIG_H_
#include "ffconst.h"
#include "flexflow/batch_config.h"
#ifdef USE_FLASHINFER
#include "flexflow/attention_config.h"
#include "flexflow/ops/kernels/gemm_impl.h"
#endif
#include "legion.h"
#include <cstring>
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#ifdef USE_FLASHINFER
#include <cublasLt.h>
#endif
#include <cublas_v2.h>
#include <cudnn.h>
#elif defined(FF_USE_HIP_ROCM)
Expand Down Expand Up @@ -89,13 +96,22 @@ struct FFHandler {
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnHandle_t dnn;
cublasHandle_t blas;
#ifdef USE_FLASHINFER
cublasLtHandle_t blasLt;
// Internal::GemmEngine *gemm_engine;
#endif
#else
miopenHandle_t dnn;
hipblasHandle_t blas;
#endif
void *workSpace;
size_t workSpaceSize;
CombinedBatchConfigMetaStruct *batch_config_metadata;
#ifdef USE_FLASHINFER
AttentionMetaData *incr_attention_metadata;
AttentionMetaData *tree_search_attention_metadata;
AttentionMetaData *tree_verify_attention_metadata;
#endif

// request info + token info + topology mask info
size_t batch_config_metadata_size = sizeof(CombinedBatchConfigMetaStruct);
Expand Down
6 changes: 6 additions & 0 deletions include/flexflow/ops/inc_multihead_self_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
bool *position_bias;
float scaling_factor;
void *devQKVProjArray, *keyCache, *valueCache;
void *queryTmp, *outputTmp;
void *qk_prods, *qk_prods_softmax;
void *attn_heads;
BatchConfig::PerTokenInfo *token_infos;
Expand All @@ -197,6 +198,11 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
// typedef hipFloatComplex attFloatComplex;
hipFloatComplex *complex_input;
#endif
// GQA
void **d_A_array, **d_B_array, **d_C_array;
void **d_A_array2, **d_B_array2, **d_C_array2;
size_t gqa_ptr_array_size;

// PEFT specific fields
void *softmax_activation_buffer;
void *query_activation_buffer;
Expand Down
Loading
Loading