Status: Open
Labels: bug (Something isn't working)
Description
Your current environment
Latest vLLM version.
🐛 Describe the bug
In vllm/distributed/device_communicators/base_device_communicator.py:reduce_scatter:
```python
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
    )
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(0, dim).contiguous()  # [my question]: should be movedim(dim, 0)?
    assert input_tensor.shape[0] % world_size == 0
    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size,) + input_tensor.shape[1:]
    output_tensor = torch.empty(
        output_shape, dtype=input_tensor.dtype, device=input_tensor.device
    )
    # Perform reduce-scatter operation
    torch.distributed.reduce_scatter_tensor(
        output_tensor, input_tensor, group=self.device_group
    )
    # Reshape before returning
    return output_tensor.movedim(0, dim).contiguous()
```

The first movedim should be movedim(dim, 0), not movedim(0, dim)! The same issue appears in reduce_scatter and reduce_scatterv in vllm/distributed/device_communicators/cuda_communicator.py.
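For reference, here is a minimal standalone sketch (single process, no torch.distributed initialization; the tensor shapes and the variable names are illustrative assumptions, not taken from vLLM) showing that the two movedim calls are not interchangeable once the input has more than two dimensions:

```python
# Minimal sketch: movedim(0, dim) vs. movedim(dim, 0) on a 3-D tensor.
import torch

x = torch.randn(4, 6, 8)
dim = 2  # suppose we want to reduce-scatter along the last dimension

# What the current code does: moves axis 0 *to* position `dim`.
current = x.movedim(0, dim)    # shape (6, 8, 4) -> axis 0 now has size 6
# What this issue suggests: moves axis `dim` *to* position 0.
proposed = x.movedim(dim, 0)   # shape (8, 4, 6) -> axis 0 is the intended scatter axis

print(current.shape)   # torch.Size([6, 8, 4])
print(proposed.shape)  # torch.Size([8, 4, 6])
```

Since torch.distributed.reduce_scatter_tensor always splits its input along dim 0, only the second form puts the requested dimension where the chunking happens; with the current form the divisibility assert checks, and the output is chunked along, an unintended axis. For 2-D inputs the two calls happen to coincide (both amount to a transpose), so the difference only shows up for tensors with three or more dimensions.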
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.