Commit 55809c9

Support mla (vllm-project#775)
1 parent baf04c8 commit 55809c9

7 files changed (+94 -88 lines)

benchmarks/benchmark_serving.py (+1)

@@ -549,6 +549,7 @@ async def benchmark(
         # multi-modal benchmark is only available on OpenAI Chat backend.
         raise ValueError(
             "Multi-modal content is only supported on 'openai-chat' backend.")
+    test_output_len = 10
     test_input = RequestFuncInput(
         model=model_id,
         model_name=model_name,

scripts/run_example_tp.py (+43 -41)

@@ -3,8 +3,8 @@
 import argparse
 import os
 
-#model_path = "/software/data/DeepSeek-R1/"
-model_path = "deepseek-ai/DeepSeek-V2-Lite"
+model_path = "/data/models/DeepSeek-R1/"
+# model_path = "deepseek-ai/DeepSeek-V2-Lite"
 
 # Parse the command-line arguments.
 parser = argparse.ArgumentParser()
@@ -13,51 +13,53 @@
 parser.add_argument("--tokenizer", type=str, default=model_path, help="The model path.")
 #parser.add_argument("--model", type=str, default="/data/models/DeepSeek-R1-bf16-small/", help="The model path.")
 #parser.add_argument("--tokenizer", type=str, default="opensourcerelease/DeepSeek-R1-bf16", help="The model path.")
-parser.add_argument("--tp_size", type=int, default=1, help="The number of threads.")
+parser.add_argument("--tp_size", type=int, default=8, help="The number of threads.")
 args = parser.parse_args()
 
 os.environ["VLLM_SKIP_WARMUP"] = "true"
 os.environ["HABANA_VISIBLE_DEVICES"] = "ALL"
 os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
-os.environ["VLLM_RAY_DISABLE_LOG_TO_DRIVER"] = "1"
-os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
+# os.environ["VLLM_RAY_DISABLE_LOG_TO_DRIVER"] = "1"
+# os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
 os.environ["VLLM_MOE_N_SLICE"] = "8"
+os.environ["VLLM_MLA_DISABLE_REQUANTIZATION"] = "1"
 
+if __name__ == "__main__":
 
-# Sample prompts.
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
-]
-# Create a sampling params object.
-sampling_params = SamplingParams(temperature=0, max_tokens=50)
-model = args.model
-if args.tp_size == 1:
-    llm = LLM(
-        model=model,
-        tokenizer=args.tokenizer,
-        trust_remote_code=True,
-        dtype="bfloat16",
-        max_model_len=1024,
-    )
-else:
-    llm = LLM(
-        model=model,
-        tokenizer=args.tokenizer,
-        tensor_parallel_size=args.tp_size,
-        distributed_executor_backend='ray',
-        trust_remote_code=True,
-        max_model_len=1024,
-        dtype="bfloat16",
-    )
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0, max_tokens=50)
+    model = args.model
+    if args.tp_size == 1:
+        llm = LLM(
+            model=model,
+            tokenizer=args.tokenizer,
+            trust_remote_code=True,
+            dtype="bfloat16",
+            max_model_len=1024,
+        )
+    else:
+        llm = LLM(
+            model=model,
+            tokenizer=args.tokenizer,
+            tensor_parallel_size=args.tp_size,
+            distributed_executor_backend='mp',
+            trust_remote_code=True,
+            max_model_len=1024,
+            dtype="bfloat16",
+        )
 
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-# Print the outputs.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

scripts/run_static-online.sh (+6 -6)

@@ -1,6 +1,6 @@
 #!/bin/bash
 tp_parrallel=8
-bs=32
+bs=96
 in_len=1024
 out_len=1024
 multi_step=1
@@ -10,12 +10,13 @@ VLLM_DECODE_BLOCK_BUCKET_MAX=$((total_len * bs / 128 + 128))
 
 # model="/data/models/DeepSeek-R1/"
 # tokenizer="/data/models/DeepSeek-R1/"
-model="/software/data/DeepSeek-R1/"
-tokenizer="/software/data/DeepSeek-R1/"
+model="/data/models/DeepSeek-R1/"
+tokenizer="/data/models/DeepSeek-R1/"
 model_name="DeepSeek-R1"
 
 HABANA_VISIBLE_DEVICES="ALL" \
-VLLM_MOE_N_SLICE=8 \
+VLLM_MOE_N_SLICE=4 \
+VLLM_MLA_DISABLE_REQUANTIZATION=1 \
 PT_HPU_ENABLE_LAZY_COLLECTIVES="true" \
 VLLM_RAY_DISABLE_LOG_TO_DRIVER="1" \
 RAY_IGNORE_UNHANDLED_ERRORS="1" \
@@ -37,7 +38,6 @@ python -m vllm.entrypoints.openai.api_server \
     --use-v2-block-manager \
     --num_scheduler_steps ${multi_step}\
     --max-model-len 2048 \
-    --max-num-batched-tokens 2048 \
     --distributed_executor_backend ray \
     --gpu_memory_utilization 0.9 \
     --trust_remote_code 2>&1 | tee benchmark_logs/serving.log &
@@ -53,7 +53,7 @@ done
 sleep 5s
 echo ${pid}
 
-num_prompts=32
+num_prompts=300
 request_rate=1
 start_time=$(date +%s)
 echo "Start to benchmark"

vllm/attention/backends/hpu_attn.py (+27 -27)

@@ -73,10 +73,9 @@ def get_kv_cache_shape(
         num_blocks: int,
         block_size: int,
         num_kv_heads: int,
-        kv_lora_rank: int,
+        head_size: int,
     ) -> Tuple[int, ...]:
-        k_pe_size = kv_lora_rank // 8
-        return (num_blocks, block_size, kv_lora_rank + k_pe_size), True
+        return (num_blocks, block_size, head_size), (num_blocks, block_size, head_size//9*8)
 
     @staticmethod
     def get_impl_cls() -> Type["HPUAttentionImpl"]:
@@ -137,7 +136,8 @@ def __init__(
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        self.latent_cache = VLLMKVCache()
+        self.latent_cache_k = VLLMKVCache()
+        self.latent_cache_v = VLLMKVCache()
         HPUFusedSDPA = kernels.fsdpa()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
@@ -186,9 +186,6 @@ def forward(
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
             input_positions = attn_metadata.input_positions.view(-1)
-            print("q_pe", q_pe.shape)
-            print("k_pe", k_pe.shape)
-            print("input_positions", attn_metadata.input_positions.shape)
             q_pe, k_pe = \
                 self.rotary_emb(input_positions, q_pe, k_pe)
         else:
@@ -197,9 +194,6 @@ def forward(
 
             q_pe = q[..., self.qk_nope_head_dim:]
 
-            # print("q_pe shape", q_pe.shape)
-            # print("k_pe shape", k_pe.shape)
-            # print("input_positions shape", attn_metadata.input_positions.shape)
             input_positions = attn_metadata.input_positions.view(-1)
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
@@ -208,15 +202,29 @@ def forward(
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
 
-        latent_vec = torch.concat(
+        latent_vec_k = torch.concat(
             (k_c_normed, k_pe.view(batch_size, -1, self.qk_rope_head_dim)), dim=-1)
         # assert layer._k_scale == 0, f"got _k_scale={layer._k_scale}"
-        # print(f"layer._k_scale={layer._k_scale}")
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+        latent_vec_v = k_c_normed.view(-1, self.kv_lora_rank)
+        if is_prefill:
+            latent_vec_k = latent_vec_k.unflatten(0, (block_indices.size(0), -1))
+            latent_vec_v = latent_vec_v.unflatten(0, (block_indices.size(0), -1))
+            # print("latent_vec", latent_vec.shape)
+
 
         # write the latent and rope to kv cache
-        if kv_cache is not None:
-            kv_cache = self.latent_cache(latent_vec, kv_cache, block_indices,
+        if kv_cache is not None and len(kv_cache) == 2:
+            # print(f"k cache shape: {kv_cache[0].shape}")
+            # print(f"v cache shape: {kv_cache[1].shape}")
+            # print(f"latent vec k shape: {latent_vec_k.shape}")
+            # print(f"latent vec v shape: {latent_vec_v.shape}")
+
+            k_cache = self.latent_cache_k(latent_vec_k, kv_cache[0], block_indices,
                                           block_offsets)
+            v_cache = self.latent_cache_v(latent_vec_v, kv_cache[1], block_indices,
+                                          block_offsets)
+            kv_cache = (k_cache, v_cache)
 
         if is_prefill:
             return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata, batch_size)
@@ -268,20 +276,14 @@ def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
+        kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
         batch_size: int
     ) -> torch.Tensor:
-        print(f"q_nope shape: {q_nope.shape}")
-        print(f"q_pe shape: {q_pe.shape}")
-
         q = torch.cat([q_nope, q_pe], dim=-1)
-        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
-        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
+        kv_c_and_k_pe_cache = kv_cache[0].unsqueeze(2)
+        kv_c_cache = kv_cache[1].unsqueeze(2)
 
-        print(f"q shape: {q.shape}")
-        print(f"kv_c_and_k_pe_cache shape: {kv_c_and_k_pe_cache.shape}")
-        print(f"kv_c_cache shape: {kv_c_cache.shape}")
         output = HPUPagedAttention.forward_decode(
             query=q,
             key_cache=kv_c_and_k_pe_cache,
@@ -296,13 +298,11 @@ def _forward_decode(
             matmul_av_op=self.matmul_av,
             batch2block_matmul_op=self.batch2block_matmul,
             block2batch_matmul_op=self.block2batch_matmul,
-            keys_fetch_func=self.latent_cache.fetch_from_cache,
-            values_fetch_func=self.latent_cache.fetch_from_cache)
+            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
+            values_fetch_func=self.latent_cache_v.fetch_from_cache)
         output = output.view(batch_size, 1, -1)
-        print("output", output.shape)
         result = self._v_up_proj_and_o_proj(output)
         result = result.view(batch_size, 1, -1)
-        print("result", result.shape)
         return result
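
Taken together with the worker changes below, this hunk switches the MLA path from a single joint latent cache to a (K, V) pair: K holds the latent plus rope part (head_size wide) and V holds only the latent part (8/9 of head_size). A minimal sketch of that convention, assuming DeepSeek-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64, so head_size=576); the helper name is mine, not the backend's API:

from typing import Tuple

def mla_kv_cache_shapes(num_blocks: int, block_size: int,
                        head_size: int) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    # K cache: joint latent + rope vector; V cache: latent only (8/9 of head_size).
    return ((num_blocks, block_size, head_size),
            (num_blocks, block_size, head_size // 9 * 8))

# DeepSeek-style dims: kv_lora_rank=512 + qk_rope_head_dim=64 -> head_size=576.
k_shape, v_shape = mla_kv_cache_shapes(num_blocks=16, block_size=128, head_size=576)
print(k_shape)  # (16, 128, 576)  latent (512) + rope (64)
print(v_shape)  # (16, 128, 512)  latent only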

vllm/worker/cache_engine.py (+2 -1)

@@ -112,7 +112,8 @@ def get_cache_block_size(
         key_cache_block = cache_config.block_size * num_heads * head_size
         # For MLA there is no value cache, since the latent vector
         # is joint keys and values.
-        value_cache_block = key_cache_block if not model_config.use_mla else 0
+        # value_cache_block = key_cache_block if not model_config.use_mla else 0
+        value_cache_block = key_cache_block // 9 * 8
         total = num_attention_layers * (key_cache_block + value_cache_block)
         if cache_config.cache_dtype == "auto":
             dtype = model_config.dtype
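
A rough, illustrative sketch of the per-block accounting after this change. The numbers assume an MLA layer with num_heads=1 and head_size=576 (latent 512 + rope 64, as in the attention backend above); they are not taken from the commit:

# Hypothetical MLA cache-block sizing in elements (dtype factor omitted).
block_size, num_heads, head_size = 128, 1, 576
key_cache_block = block_size * num_heads * head_size   # 73728: latent + rope per block
value_cache_block = key_cache_block // 9 * 8           # 65536: latent-only per block
print(key_cache_block, value_cache_block,
      key_cache_block + value_cache_block)             # 73728 65536 139264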

vllm/worker/hpu_worker.py (+11 -12)

@@ -568,25 +568,24 @@ def _allocate_kv_cache(
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)
 
         use_mla = False
-        if len(kv_cache_shape) == 2 and kv_cache_shape[1]:
+        if len(kv_cache_shape) == 2:
             use_mla = True
-            kv_cache_shape = kv_cache_shape[0]
+            k_cache_shape = kv_cache_shape[0]
+            v_cache_shape = kv_cache_shape[1]
+        else:
+            k_cache_shape = kv_cache_shape
+            v_cache_shape = kv_cache_shape
 
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         dtype = self.dtype
         if device != 'hpu' and not is_fake_hpu() \
                 and self.dtype == torch.float8_e4m3fn:
             dtype = torch.uint8
         for _ in range(self.num_attention_layers):
-            if use_mla:
-                kv_layer = torch.zeros(kv_cache_shape,
-                                       dtype=dtype,
-                                       device=device)
-            else:
-                key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=device)
-                value_cache = torch.zeros(kv_cache_shape,
-                                          dtype=dtype,
-                                          device=device)
-                kv_layer = (key_cache, value_cache)
+            key_cache = torch.zeros(k_cache_shape, dtype=dtype, device=device)
+            value_cache = torch.zeros(v_cache_shape,
+                                      dtype=dtype,
+                                      device=device)
+            kv_layer = (key_cache, value_cache)
             kv_cache.append(kv_layer)
         return kv_cache
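
A compact sketch of the allocation flow this hunk settles on, using made-up shapes: a pair of shapes from the attention backend signals MLA, and every layer now gets a (key_cache, value_cache) tuple, just with different last dimensions in the MLA case:

import torch

def allocate_layer(kv_cache_shape, dtype=torch.bfloat16, device="cpu"):
    # MLA backends return a (k_shape, v_shape) pair; others return a single
    # shape tuple that is reused for both tensors.
    if len(kv_cache_shape) == 2 and isinstance(kv_cache_shape[0], tuple):
        k_shape, v_shape = kv_cache_shape
    else:
        k_shape = v_shape = kv_cache_shape
    return (torch.zeros(k_shape, dtype=dtype, device=device),
            torch.zeros(v_shape, dtype=dtype, device=device))

# Hypothetical MLA shapes (see the hpu_attn.py hunk above).
k, v = allocate_layer(((16, 128, 576), (16, 128, 512)))
print(k.shape, v.shape)  # torch.Size([16, 128, 576]) torch.Size([16, 128, 512])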

vllm/worker/model_runner_base.py (+4 -1)

@@ -48,7 +48,10 @@ def _init_attn_metadata_from_tensor_dict(
     valid_attn_kwargs = {}
     for field in dataclasses.fields(attn_backend.get_metadata_cls()):
         if field.name in tensor_dict:
-            valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
+            if field.name == "input_positions":
+                valid_attn_kwargs[field.name] = tensor_dict[field.name]
+            else:
+                valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
 
     attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
     tensor_dict["attn_metadata"] = attn_metadata
