@@ -1,6 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 from typing import Dict, List, Optional, Tuple
 
+import time
 import torch
 import torch.nn as nn
 
@@ -37,13 +38,20 @@ def forward(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
+        logits: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias,
-                             self.vocab_size)
+        if logits is None:
+            # Get the hidden states that we use for sampling.
+            hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+
+            start = time.time()
+            # Get the logits for the next tokens.
+            logits = _get_logits(hidden_states, embedding, embedding_bias,
+                                 self.vocab_size)
+            end = time.time()
+            print(f'Out-of-model logits calculation (MatMul) took {(end - start)*1000} ms')
+        else:
+            logits = _prune_hidden_states(logits, sampling_metadata)
 
         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
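The new `logits` parameter lets a caller that has already projected hidden states to vocabulary logits (for example, inside the model's own forward pass) skip the out-of-model matmul; the `else` branch then prunes the supplied tensor down to the sampled positions, mirroring what the `logits is None` path does to `hidden_states`. A minimal sketch of why the two paths agree, assuming pruning amounts to an `index_select` over the selected token indices (shapes and index values below are illustrative, not taken from the change):

    import torch

    vocab_size, hidden_size, num_tokens = 32000, 512, 6
    hidden_states = torch.randn(num_tokens, hidden_size)
    weight = torch.randn(vocab_size, hidden_size)   # lm_head / embedding weight
    selected = torch.tensor([2, 5])                 # positions to sample from

    # Path 1 (logits is None): prune hidden states first, then project.
    pruned = hidden_states.index_select(0, selected)
    logits_a = pruned @ weight.t()

    # Path 2 (precomputed logits): project every token, then prune the logits.
    logits_b = (hidden_states @ weight.t()).index_select(0, selected)

    # Same result either way, up to floating-point noise from kernel tiling.
    assert torch.allclose(logits_a, logits_b, atol=1e-4)

Note that path 2 projects every token and discards most rows, so passing precomputed logits only pays off when the projection already happens inside the model for other reasons, which is what this change appears to be probing.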
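One caveat on the added instrumentation: CUDA kernels launch asynchronously, so `time.time()` around `_get_logits` can measure little more than kernel launch overhead unless the device is synchronized first. A sketch of a synchronized measurement, using a plain matmul as a stand-in for `_get_logits` (the helper name and the projection it times are assumptions, not the change's actual code):

    import time

    import torch

    def timed_logits_matmul(hidden_states: torch.Tensor,
                            weight: torch.Tensor) -> torch.Tensor:
        # Drain any queued kernels so the timer starts from an idle device.
        if hidden_states.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        logits = hidden_states @ weight.t()
        # Wait for the matmul to actually finish before stopping the clock.
        if hidden_states.is_cuda:
            torch.cuda.synchronize()
        end = time.perf_counter()
        print(f'Out-of-model logits calculation (MatMul) took '
              f'{(end - start) * 1000:.3f} ms')
        return logits

For finer-grained device timing without blocking the host on every call, a pair of `torch.cuda.Event(enable_timing=True)` events around the matmul works as well.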