From 43aabb127b70f789a3182909a2a50bbe8e43f8f5 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Mon, 11 Nov 2024 10:37:18 +0800
Subject: [PATCH] Add XPU communicator and raise VLLM_RPC_TIMEOUT

---
 .../device_communicators/xpu_communicator.py | 47 +++++++++++++++++++
 vllm/distributed/parallel_state.py           | 18 +++++++
 vllm/envs.py                                 |  2 +-
 3 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 vllm/distributed/device_communicators/xpu_communicator.py

diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
new file mode 100644
index 0000000000000..ff2b8b40a809f
--- /dev/null
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -0,0 +1,47 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.platforms import current_platform
+
+
+class XpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_xpu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def gather(self,
+               input_: torch.Tensor,
+               rank_in_group: int,
+               dst: int = 0,
+               dim: int = -1):
+        # On the XPU path, gather doesn't work properly together with a Ray
+        # cluster, so we use all_gather instead for now.
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty((self.world_size, ) + input_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        torch.distributed.all_gather_into_tensor(output_tensor,
+                                                 input_,
+                                                 group=self.group)
+        if rank_in_group == dst:
+            # Reshape so the gathered chunks are concatenated along `dim`.
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                  (self.world_size *
+                                                   input_size[dim], ) +
+                                                  input_size[dim + 1:])
+        else:
+            output_tensor = None
+        return output_tensor
\ No newline at end of file
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d3ac4eb78b155..ccb5723a77ab0 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -166,6 +166,7 @@ def __init__(
         use_pynccl: bool,
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
+        use_xpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -202,6 +203,7 @@ def __init__(
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
         self.use_tpu_communicator = use_tpu_communicator
+        self.use_xpu_communicator = use_xpu_communicator
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -230,6 +232,12 @@ def __init__(
         if use_tpu_communicator and self.world_size > 1:
             self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
 
+        from vllm.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator)
+        self.xpu_communicator: Optional[XpuCommunicator] = None
+        if use_xpu_communicator and self.world_size > 1:
+            self.xpu_communicator = XpuCommunicator(group=self.device_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
@@ -345,6 +353,10 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             # TPU handles Dynamo with its own logic.
             return self._all_reduce(input_)
 
+        if self.xpu_communicator is not None and \
+            not self.xpu_communicator.disabled:
+            return self.xpu_communicator.all_reduce(input_)
+
         if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
@@ -433,6 +445,10 @@ def gather(self,
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
+        if self.xpu_communicator is not None and \
+            not self.xpu_communicator.disabled:
+            return self.xpu_communicator.gather(input_, self.rank_in_group,
+                                                dst, dim)
         # Allocate output tensor.
         if self.rank_in_group == dst:
             gather_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -847,6 +863,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_pynccl=False,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
+        use_xpu_communicator=False,
         group_name="world",
     )
 
@@ -868,6 +885,7 @@ def init_model_parallel_group(
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
+        use_xpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
diff --git a/vllm/envs.py b/vllm/envs.py
index 705d858e71a66..eb26a8825d9df 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -57,7 +57,7 @@
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
    VLLM_TEST_FORCE_FP8_MARLIN: bool = False
-    VLLM_RPC_TIMEOUT: int = 10000  # ms
+    VLLM_RPC_TIMEOUT: int = 100000  # ms
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
     VLLM_USE_TRITON_AWQ: bool = False
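
Note on the gather path: XpuCommunicator.gather emulates a destination-rank gather with all_gather_into_tensor followed by movedim/reshape. The snippet below is a minimal single-process sketch of that reshape arithmetic only; torch.stack stands in for the collective, and the world size, gather dimension, and tensor shape are arbitrary example values, not taken from the patch.

# Single-process sketch of the gather-as-all_gather reshape used above.
# torch.stack stands in for dist.all_gather_into_tensor; world_size, dim and
# input_size are made-up example values.
import torch

world_size = 4
dim = 1                  # gather dimension (already normalized to be >= 0)
input_size = (2, 3, 5)   # per-rank tensor shape

# Per-rank inputs that the collective would gather, rank-major.
per_rank = [torch.randn(input_size) for _ in range(world_size)]
output_tensor = torch.stack(per_rank, dim=0)   # shape: (world_size, 2, 3, 5)

# Same arithmetic as XpuCommunicator.gather: move the rank axis to `dim`,
# then fold it into that dimension.
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
                                      (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])

# The result equals concatenating the per-rank tensors along `dim`.
assert torch.equal(output_tensor, torch.cat(per_rank, dim=dim))
print(output_tensor.shape)   # torch.Size([2, 12, 5])

On the destination rank this matches what a gather followed by a concatenation along `dim` would produce, which is why the XPU path can substitute all_gather without changing callers of GroupCoordinator.gather.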