Commit e4acb2d
[feat] support customized and separated hccl_buffer_size for process group initialization (#3073)
### What this PR does / why we need it?

Currently, users have to set `HCCL_BUFFSIZE` to 512~1024 to run the mc2 operators (dispatch and combine) when serving MoE models with a large `ep_size` and `batch_size`. This environment variable not only sets the VRAM allocated for the mc2 group, but also inflates the VRAM allocated for the dp, tp & ep groups, leading to significant drops in kv_cache and free_memory.

This PR automatically calculates and sets `hccl_buffer_size` for each process group **(except the mc2 group)** separately when users set `HCCL_BUFFSIZE` for the mc2 group. This significantly reduces the buffer size wasted on the dp, tp & ep groups.

Note that the current mc2 operators can only partition communication space based on the `HCCL_BUFFSIZE` configuration. Once they support an `hccl_buffer_size` passed through `pg_options` at process group initialization, we will calculate the required buffer size for the mc2 group as well, and users will no longer need to set `HCCL_BUFFSIZE` themselves.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

We performed E2E serving of deepseek_r1 with the DP/TP/EP/MC2 process groups initialized and observed a significant increase in kv_cache and free_memory.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: linfeng-yuan <[email protected]>
1 parent 9eb1036 commit e4acb2d
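The intended effect can be pictured with a rough per-rank buffer accounting. The sketch below is illustrative only: the 512 MB mc2 figure comes from the description above, while the 50 MB dp value and 200 MB default come from the new constants in `vllm_ascend/utils.py`; the group names and totals are example numbers, not measurements.

```python
# Illustrative per-rank HCCL buffer accounting (MB); values are examples, not measured.
GLOBAL_HCCL_BUFFSIZE_MB = 512  # what the mc2 dispatch/combine operators need via HCCL_BUFFSIZE

# Before: every HCCL group inherits the global HCCL_BUFFSIZE.
before = {name: GLOBAL_HCCL_BUFFSIZE_MB for name in ("mc2", "dp", "tp", "ep")}

# After: only mc2 keeps the large buffer; the other groups get per-group pg_options.
after = {
    "mc2": GLOBAL_HCCL_BUFFSIZE_MB,  # still driven by HCCL_BUFFSIZE (see FIXME in utils.py)
    "dp": 50,                        # calculate_dp_buffer_size(): tiny payload, clamped to 50 MB
    "tp": 200,                       # default hccl_buffer_size (200 MB)
    "ep": 200,
}

saved_mb = sum(before.values()) - sum(after.values())
print(f"Example per-rank HCCL buffer saving: {saved_mb} MB")
```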

File tree: 4 files changed, +143 -6 lines changed

tests/ut/patch/worker/patch_common/test_patch_distributed.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -29,7 +29,7 @@ def setUp(self):
         self.mock_group_ranks = [[0, 1]]
         self.mock_local_rank = 0
         self.mock_backend = "hccl"
-        self.mock_use_device_comm = True
+        self.mock_use_device_comm = False
 
         patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
         patcher_new_group = patch("torch.distributed.new_group",
@@ -39,16 +39,24 @@ def setUp(self):
         patcher_device_comm_cls = patch(
             "vllm.distributed.parallel_state.resolve_obj_by_qualname",
             return_value=MagicMock())
+        patcher_calculate_dp_buffer = patch(
+            "vllm_ascend.utils.calculate_dp_buffer_size", return_value=64)
+        patcher_npu_current_device = patch("torch.npu.current_device",
+                                           return_value=MagicMock())
 
         self.mock_get_rank = patcher_get_rank.start()
         self.mock_new_group = patcher_new_group.start()
         self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
         self.mock_resolve_obj = patcher_device_comm_cls.start()
+        self.mock_calculate_dp_buffer = patcher_calculate_dp_buffer.start()
+        self.mock_npu_current_device = patcher_npu_current_device.start()
 
         self.addCleanup(patcher_get_rank.stop)
         self.addCleanup(patcher_new_group.stop)
         self.addCleanup(patcher_is_cuda_alike.stop)
         self.addCleanup(patcher_device_comm_cls.stop)
+        self.addCleanup(patcher_calculate_dp_buffer.stop)
+        self.addCleanup(patcher_npu_current_device.stop)
 
         self.group_coordinator = GroupCoordinatorPatch(
             group_ranks=self.mock_group_ranks,
```

vllm_ascend/patch/__init__.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -87,6 +87,19 @@
 # ** File: worker/patch_common/patch_distributed.py **
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # 1. `vllm.distributed.parallel_state.GroupCoordinator`
+#    (1) __init__()
+#    Why:
+#       The original GroupCoordinator initialization lacks pg_options to generate new
+#       process group with customized options.
+#    How:
+#       Inject HCCL options during process group initialization.
+#    Related PR (if no, explain why):
+#       Need a PR to vllm to support a dictionary as input while initializing distributed
+#       environment (e.g., Dict[str, torch.distributed.ProcessGroupHCCL.Options])
+#       https://github.com/vllm-project/vllm/pull/25417
+#    Future Plan:
+#       Remove this patch when vllm merges this PR.
+#    (2) all_to_all()
 #    Why:
 #       vllm doesn't support all_to_all for GroupCoordinator.
 #    How:
```

vllm_ascend/patch/worker/patch_common/patch_distributed.py

Lines changed: 71 additions & 5 deletions
```diff
@@ -15,17 +15,82 @@
 # limitations under the License.
 #
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 import vllm
-from vllm.distributed.parallel_state import GroupCoordinator
+from torch.distributed import Backend
+from vllm.distributed.parallel_state import (GroupCoordinator,
+                                             _get_unique_name, _register_group)
+
+from vllm_ascend.distributed.communicator import NPUCommunicator
+from vllm_ascend.utils import create_hccl_pg_options
 
 
 class GroupCoordinatorPatch(GroupCoordinator):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        group_ranks: list[list[int]],
+        local_rank: int,
+        torch_distributed_backend: Union[str, Backend],
+        use_device_communicator: bool,  # whether to use device communicator
+        use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
+    ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
+
+        self.rank = torch.distributed.get_rank()
+        self.local_rank = local_rank
+
+        self_device_group = None
+        self_cpu_group = None
+        hccl_pg_options = create_hccl_pg_options(group_name)
+
+        for ranks in group_ranks:
+            device_group = torch.distributed.new_group(
+                ranks,
+                backend=torch_distributed_backend,
+                pg_options=hccl_pg_options)
+
+            # a group with `gloo` backend, to allow direct coordination between
+            # processes through the CPU.
+            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+            if self.rank in ranks:
+                self.ranks = ranks
+                self.world_size = len(ranks)
+                self.rank_in_group = ranks.index(self.rank)
+                self_device_group = device_group
+                self_cpu_group = cpu_group
+
+        assert self_cpu_group is not None
+        assert self_device_group is not None
+
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group
+        self.device = torch.npu.current_device()
+
+        self.use_device_communicator = use_device_communicator
+        self.device_communicator = None
+        if use_device_communicator and self.world_size > 1:
+            self.device_communicator = NPUCommunicator(
+                cpu_group=self.cpu_group,
+                device=self.device,
+                device_group=self.device_group,
+                unique_name=self.unique_name,
+            )
+
+        from vllm.distributed.device_communicators.shm_broadcast import \
+            MessageQueue
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
+                self.cpu_group, 1 << 22, 6)
+
+        self.use_custom_op_call = False
+        self.use_cpu_custom_send_recv = False
 
     def all_to_all(self,
                    input_: torch.Tensor,
@@ -41,9 +106,10 @@ def all_to_all(self,
         assert -input_.dim() <= gather_dim < input_.dim(), (
             f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
         )
+        assert self.device_communicator is not None, "device_communicator should be initialized when world_size > 1"
         return self.device_communicator.all_to_all(input_, scatter_dim,
                                                    gather_dim, scatter_sizes,
                                                    gather_sizes)
 
 
-vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
```
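The key mechanic the reimplemented `__init__` adds is passing HCCL options into `torch.distributed.new_group` via `pg_options`. Below is a minimal standalone sketch of that flow, assuming an Ascend environment where `torch_npu` exposes `ProcessGroupHCCL.Options` with an `hccl_config` field (as used in the patch) and the default process group has already been initialized with the `hccl` backend; the buffer size and ranks are example values.

```python
import torch
import torch_npu  # provides the HCCL backend and ProcessGroupHCCL bindings

# Per-group HCCL options; hccl_config mirrors what create_hccl_pg_options()
# in vllm_ascend/utils.py returns for a non-mc2 group.
options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
options.hccl_config = {"hccl_buffer_size": 50}  # MB, e.g. the dp group

# Requires torch.distributed.init_process_group("hccl", ...) to have run already;
# ranks=[0, 1] is just an example subgroup.
group = torch.distributed.new_group(ranks=[0, 1],
                                    backend="hccl",
                                    pg_options=options)
```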

vllm_ascend/utils.py

Lines changed: 50 additions & 0 deletions
```diff
@@ -53,6 +53,8 @@
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
+_DEFAULT_BUFFER_SIZE = 200
+_MIN_DP_BUFFER_SIZE = 50
 
 
 def is_310p():
@@ -648,3 +650,51 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+def create_hccl_pg_options(group_name: str):
+    options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
+    hccl_config = get_hccl_config_for_pg_options(group_name)
+    if hccl_config is not None:
+        options.hccl_config = hccl_config
+    return options
+
+
+def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
+    """
+    Get HCCL process group options for the given communication group name.
+
+    Args:
+        group_name: Name of the communication group
+
+    Returns:
+        HCCL pg_options or None for mc2 group
+    """
+    # FIXME: Current mc2 operators only perform communication space partitioning
+    # based on HCCL_BUFFSIZE configuration. Using pg_options with mc2 group would
+    # result in memory misalignment problems.
+    if group_name and "mc2" in group_name:
+        return None
+    hccl_config_map = {
+        "dp": {
+            "hccl_buffer_size": calculate_dp_buffer_size()
+        },
+    }
+    return hccl_config_map.get(group_name, get_default_buffer_config())
+
+
+def get_default_buffer_config() -> dict:
+    return {"hccl_buffer_size": _DEFAULT_BUFFER_SIZE}
+
+
+def calculate_dp_buffer_size() -> int:
+    """
+    formula of dp buffer size:
+    dp_size + 2 (flags: with_prefill and enable_dbo)
+    """
+    from vllm.config import get_current_vllm_config
+    vllm_config = get_current_vllm_config()
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    int32_size = torch.iinfo(torch.int32).bits // 8
+    dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
+    return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
```
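As a worked example of the `calculate_dp_buffer_size` formula above (a standalone sketch with the constants copied from the diff, no vLLM config required): the payload is (dp_size + 2) int32 flags rounded up to whole MB, so the 50 MB floor dominates for any realistic `dp_size`.

```python
import math

_MIN_DP_BUFFER_SIZE = 50  # MB, as defined above


def dp_buffer_size_mb(dp_size: int) -> int:
    int32_size = 4  # bytes, i.e. torch.iinfo(torch.int32).bits // 8
    raw_mb = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
    return max(raw_mb, _MIN_DP_BUFFER_SIZE)


print(dp_buffer_size_mb(32))    # (32 + 2) * 4 B = 136 B -> 1 MB -> clamped to 50 MB
print(dp_buffer_size_mb(4096))  # ~16 KB of payload -> still 50 MB
```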
