Add xpu communicator #50

Open · wants to merge 1 commit into base: 062_test_0929

47 changes: 47 additions & 0 deletions vllm/distributed/device_communicators/xpu_communicator.py
@@ -0,0 +1,47 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform


class XpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(x, group=self.group)
        return x

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.group)
        if rank_in_group == dst:
            # Reshape
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            output_tensor = None
        return output_tensor
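
Note on the gather path: because dist.gather misbehaves with Ray on the XPU path, the code above emulates it with all_gather_into_tensor followed by movedim/reshape, and returns None on non-destination ranks. Below is a minimal, self-contained sketch of that same trick, run on CPU with the gloo backend purely for illustration; the port, tensor shape, and world size are arbitrary assumptions and are not part of this PR.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    dst, dim = 0, -1
    # Each rank contributes a (2, 3) tensor filled with its own rank id.
    input_ = torch.full((2, 3), float(rank))
    input_size = input_.size()
    if dim < 0:
        dim += input_.dim()

    # All-gather into a (world_size, 2, 3) buffer on every rank.
    output = torch.empty((world_size, ) + input_size, dtype=input_.dtype)
    dist.all_gather_into_tensor(output, input_)

    if rank == dst:
        # Move the gathered axis to `dim` and fold it in, mimicking dist.gather.
        output = output.movedim(0, dim)
        output = output.reshape(input_size[:dim] +
                                (world_size * input_size[dim], ) +
                                input_size[dim + 1:])
        print(output.shape)  # torch.Size([2, 9]) when world_size == 3
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(3, ), nprocs=3)
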
18 changes: 18 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -166,6 +166,7 @@ def __init__(
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_xpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
@@ -202,6 +203,7 @@ def __init__(
        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator
        self.use_xpu_communicator = use_xpu_communicator

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -230,6 +232,12 @@ def __init__(
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from vllm.distributed.device_communicators.xpu_communicator import (
            XpuCommunicator)
        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
            self.xpu_communicator = XpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
@@ -345,6 +353,10 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
            # TPU handles Dynamo with its own logic.
            return self._all_reduce(input_)

        if self.xpu_communicator is not None and \
                not self.xpu_communicator.disabled:
            return self.xpu_communicator.all_reduce(input_)

        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
            return torch.ops.vllm.outplace_all_reduce(
                input_, group_name=self.unique_name)
@@ -433,6 +445,10 @@ def gather(self,
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        if self.xpu_communicator is not None and \
                not self.xpu_communicator.disabled:
            return self.xpu_communicator.gather(input_, self.rank_in_group,
                                                dst, dim)
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -847,6 +863,7 @@ def init_world_group(ranks: List[int], local_rank: int,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        use_xpu_communicator=False,
        group_name="world",
    )

Expand All @@ -868,6 +885,7 @@ def init_model_parallel_group(
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_xpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
Expand Down
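
Since init_model_parallel_group now passes use_xpu_communicator=True, the usual tensor-parallel collectives route through XpuCommunicator whenever the platform is XPU and the group spans more than one rank. A hedged usage sketch follows, assuming an XPU build of vLLM with intel_extension_for_pytorch and the oneCCL ("ccl") process-group backend installed, launched with two ranks (e.g. via torchrun); the function names come from vllm.distributed.parallel_state, while the backend name, device string, and shapes are assumptions, not part of this diff.

import torch

from vllm.distributed.parallel_state import (
    get_tp_group, init_distributed_environment, initialize_model_parallel)

# Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are provided by the launcher
# and that the "ccl" (oneCCL) backend is available for XPU tensors.
init_distributed_environment(backend="ccl")
initialize_model_parallel(tensor_model_parallel_size=2)

x = torch.ones(4, 8, device="xpu")       # one shard per tensor-parallel rank
x = get_tp_group().all_reduce(x)         # dispatches to XpuCommunicator.all_reduce
full = get_tp_group().gather(x, dst=0)   # all_gather + reshape; None on non-dst ranks
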
2 changes: 1 addition & 1 deletion vllm/envs.py
@@ -57,7 +57,7 @@
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_RPC_TIMEOUT: int = 100000 # ms
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False