File tree 3 files changed +47
-5
lines changed
3 files changed +47
-5
lines changed Original file line number Diff line number Diff line change @@ -17,8 +17,9 @@ ray[default]
17
17
--find-links https://storage.googleapis.com/libtpu-releases/index.html
18
18
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
19
19
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
20
-
21
- torch==2.7.0.dev20250226+cpu
22
- torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
23
- torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
24
- torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
20
+ torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
21
+ torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
22
+ torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
23
+ torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
24
+ torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
25
+ torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
Original file line number Diff line number Diff line change 7
7
import torch .distributed
8
8
import torch .nn as nn
9
9
import torch_xla .core .xla_model as xm
10
+ import torch_xla .debug .profiler as xp
10
11
import torch_xla .runtime as xr
11
12
12
13
import vllm .envs as envs
@@ -65,6 +66,15 @@ def __init__(
65
66
from vllm .utils import init_cached_hf_modules
66
67
init_cached_hf_modules ()
67
68
69
+ self .profiler = None
70
+ if envs .VLLM_TORCH_PROFILER_DIR and self .rank < 1 :
71
+ # For TPU, we can only have 1 active profiler session for 1 profiler
72
+ # server. So we only profile on rank0.
73
+ self .profile_dir = envs .VLLM_TORCH_PROFILER_DIR
74
+ logger .info ("Profiling enabled. Traces will be saved to: %s" ,
75
+ self .profile_dir )
76
+ self .profiler = xp .start_server (9012 )
77
+
68
78
def init_device (self ):
69
79
os .environ ["PJRT_DEVICE" ] = "TPU"
70
80
torch .set_grad_enabled (False )
@@ -152,6 +162,15 @@ def execute_model(
152
162
output = self .model_runner .execute_model (scheduler_output )
153
163
return output if self .is_driver_worker else None
154
164
165
+ def profile (self , is_start : bool = True ):
166
+ if self .rank < 1 :
167
+ if self .profiler is None :
168
+ raise RuntimeError ("Profiler is not enabled." )
169
+ if is_start :
170
+ xp .start_trace (self .profile_dir )
171
+ else :
172
+ xp .stop_trace ()
173
+
155
174
def load_model (self ) -> None :
156
175
self .model_runner .load_model ()
157
176
Original file line number Diff line number Diff line change 5
5
6
6
import torch
7
7
import torch_xla .core .xla_model as xm
8
+ import torch_xla .debug .profiler as xp
8
9
import torch_xla .runtime as xr
9
10
10
11
import vllm .envs as envs
@@ -93,6 +94,27 @@ def init_device(self) -> None:
93
94
f"tp{ world_size } _rank{ rank } " )
94
95
xr .initialize_cache (per_rank_path , readonly = False )
95
96
97
+ self .profiler = None
98
+ if envs .VLLM_TORCH_PROFILER_DIR and self .rank < 1 :
99
+ # For TPU, we can only have 1 active profiler session for 1 profiler
100
+ # server. So we only profile on rank0.
101
+ self .profile_dir = envs .VLLM_TORCH_PROFILER_DIR
102
+ logger .info ("Profiling enabled. Traces will be saved to: %s" ,
103
+ self .profile_dir )
104
+ self .profiler = xp .start_server (9012 )
105
+
106
+ def start_profile (self ):
107
+ if self .rank < 1 :
108
+ if self .profiler is None :
109
+ raise RuntimeError ("Profiler is not enabled." )
110
+ xp .start_trace (self .profile_dir )
111
+
112
+ def stop_profile (self ):
113
+ if self .rank < 1 :
114
+ if self .profiler is None :
115
+ raise RuntimeError ("Profiler is not enabled." )
116
+ xp .stop_trace ()
117
+
96
118
def load_model (self ):
97
119
self .model_runner .load_model ()
98
120
You can’t perform that action at this time.
0 commit comments