Skip to content

Commit beebf47

Browse files
lsy323 and mgoin authored
[TPU][Profiler] Support start_profile/stop_profile in TPU worker (vllm-project#13988)
Signed-off-by: Siyuan Liu <lsiyuan@google.com> Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent f89978a commit beebf47

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

requirements-tpu.txt

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ ray[default]
1717
--find-links https://storage.googleapis.com/libtpu-releases/index.html
1818
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
1919
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
20-
21-
torch==2.7.0.dev20250226+cpu
22-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
23-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
24-
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
20+
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
21+
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
22+
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
23+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
24+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
25+
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

vllm/v1/worker/tpu_worker.py

+19
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.distributed
88
import torch.nn as nn
99
import torch_xla.core.xla_model as xm
10+
import torch_xla.debug.profiler as xp
1011
import torch_xla.runtime as xr
1112

1213
import vllm.envs as envs
@@ -65,6 +66,15 @@ def __init__(
6566
from vllm.utils import init_cached_hf_modules
6667
init_cached_hf_modules()
6768

69+
self.profiler = None
70+
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
71+
# For TPU, we can only have 1 active profiler session for 1 profiler
72+
# server. So we only profile on rank0.
73+
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
74+
logger.info("Profiling enabled. Traces will be saved to: %s",
75+
self.profile_dir)
76+
self.profiler = xp.start_server(9012)
77+
6878
def init_device(self):
6979
os.environ["PJRT_DEVICE"] = "TPU"
7080
torch.set_grad_enabled(False)
@@ -152,6 +162,15 @@ def execute_model(
152162
output = self.model_runner.execute_model(scheduler_output)
153163
return output if self.is_driver_worker else None
154164

165+
def profile(self, is_start: bool = True):
    """Start or stop an XLA profiler trace on the driver rank.

    Only rank 0 hosts the profiler server (one active session per
    server), so every other rank is a silent no-op.

    Args:
        is_start: True to begin a trace into ``self.profile_dir``,
            False to end the in-flight trace.

    Raises:
        RuntimeError: if profiling was never enabled for this worker.
    """
    if self.rank >= 1:
        # Non-driver ranks have no profiler server; nothing to do.
        return
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    if is_start:
        xp.start_trace(self.profile_dir)
    else:
        xp.stop_trace()
173+
155174
def load_model(self) -> None:
    """Load the model weights by delegating to the underlying model runner."""
    self.model_runner.load_model()
157176

vllm/worker/tpu_worker.py

+22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import torch
77
import torch_xla.core.xla_model as xm
8+
import torch_xla.debug.profiler as xp
89
import torch_xla.runtime as xr
910

1011
import vllm.envs as envs
@@ -93,6 +94,27 @@ def init_device(self) -> None:
9394
f"tp{world_size}_rank{rank}")
9495
xr.initialize_cache(per_rank_path, readonly=False)
9596

97+
self.profiler = None
98+
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
99+
# For TPU, we can only have 1 active profiler session for 1 profiler
100+
# server. So we only profile on rank0.
101+
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
102+
logger.info("Profiling enabled. Traces will be saved to: %s",
103+
self.profile_dir)
104+
self.profiler = xp.start_server(9012)
105+
106+
def start_profile(self):
    """Begin capturing an XLA trace into ``self.profile_dir``.

    Only rank 0 runs the profiler server, so other ranks return
    immediately. Raises RuntimeError if profiling was not enabled.
    """
    if self.rank >= 1:
        # Profiling runs exclusively on the driver rank.
        return
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    xp.start_trace(self.profile_dir)
111+
112+
def stop_profile(self):
    """End the in-flight XLA trace.

    Mirrors :meth:`start_profile`: a no-op on every rank but 0, and a
    RuntimeError when profiling was never enabled.
    """
    if self.rank >= 1:
        # Only the driver rank owns a profiler session.
        return
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    xp.stop_trace()
117+
96118
def load_model(self):
    """Forward model loading to the wrapped model runner."""
    self.model_runner.load_model()
98120

0 commit comments

Comments
 (0)