Skip to content

Commit 8a9862f

Browse files
authored
Merge pull request #4 from slyalin/optimum_models
Support for optimum-intel models
2 parents ce80498 + d6a323a commit 8a9862f

File tree

2 files changed

+332
-72
lines changed

2 files changed

+332
-72
lines changed

vllm/model_executor/layers/sampler.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""A layer that samples the next tokens from the model's outputs."""
22
from typing import Dict, List, Optional, Tuple
33

4+
import time
45
import torch
56
import torch.nn as nn
67

@@ -37,13 +38,20 @@ def forward(
3738
hidden_states: torch.Tensor,
3839
sampling_metadata: SamplingMetadata,
3940
embedding_bias: Optional[torch.Tensor] = None,
41+
logits: Optional[torch.Tensor] = None,
4042
) -> Optional[SamplerOutput]:
41-
# Get the hidden states that we use for sampling.
42-
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
43-
44-
# Get the logits for the next tokens.
45-
logits = _get_logits(hidden_states, embedding, embedding_bias,
46-
self.vocab_size)
43+
if logits is None:
44+
# Get the hidden states that we use for sampling.
45+
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
46+
47+
start = time.time()
48+
# Get the logits for the next tokens.
49+
logits = _get_logits(hidden_states, embedding, embedding_bias,
50+
self.vocab_size)
51+
end = time.time()
52+
print(f'Out-of-model logits calculation (MatMul) took {(end - start)*1000} ms')
53+
else:
54+
logits = _prune_hidden_states(logits, sampling_metadata)
4755

4856
# Only perform sampling in the driver worker.
4957
# Note: `_get_logits` is still distributed across TP workers because

0 commit comments

Comments (0)