
Commit ba28d54

Authored by AlvisGong, clrs97, zzhx1, and Kurumi5210
[Perf] enable prefill flashcommon3 (#4065)
### What this PR does / why we need it?
MoE multistream overlap to improve the performance.

### How was this patch tested?
--additional-config '{"multistream_overlap_gate": true}'

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: AlvisGong <[email protected]>
Signed-off-by: chenxiao <[email protected]>
Co-authored-by: clrs97 <[email protected]>
Co-authored-by: zzhx1 <[email protected]>
Co-authored-by: chenxiao <[email protected]>
1 parent 0686b32 commit ba28d54

File tree: 8 files changed, +239 -40 lines changed

tests/ut/ops/test_prepare_finalize.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,10 @@ class TestPrepareAndFinalize(unittest.TestCase):
 
     def setUp(self):
         # Mock FusedMoEConfig
+        fake_stream = MagicMock()
+        patcher = patch("torch.npu.Stream", return_value=fake_stream)
+        patcher.start()
+        self.addCleanup(patcher.stop)
         self.moe_config = MagicMock(spec=FusedMoEConfig)
         self.moe_config.tp_group = MagicMock()
         self.moe_config.tp_group.device_group = MagicMock()

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions

@@ -106,6 +106,8 @@ def __init__(self, vllm_config):
             enable_shared_expert_dp=True)
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
+        self.multistream_overlap_gate = additional_config.get(
+            "multistream_overlap_gate", False)
         self.recompute_scheduler_enable = additional_config.get(
             "recompute_scheduler_enable", False)
         self.enable_cpu_binding = additional_config.get(
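
The new flag rides on vLLM's generic `additional_config` mechanism, so it can be enabled without code changes. Below is a hedged sketch of turning it on from Python; the model path is a placeholder, and whether `additional_config` is accepted as an `LLM()` keyword can vary by vLLM version (the `--additional-config` CLI flag shown in the PR description is the tested path).

```python
# Hedged sketch: enabling the flag added above through additional_config.
# The model path is a placeholder; if your vLLM version does not accept
# `additional_config` as an LLM() keyword, pass it on the command line as
# --additional-config '{"multistream_overlap_gate": true}' instead.
from vllm import LLM

llm = LLM(
    model="/path/to/a/moe/model",  # placeholder
    additional_config={"multistream_overlap_gate": True},
)
```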

vllm_ascend/distributed/parallel_state.py

Lines changed: 21 additions & 1 deletion

@@ -20,9 +20,10 @@
 _LMTP: Optional[GroupCoordinator] = None
 _EMBED_TP: Optional[GroupCoordinator] = None
 
-# flashcomm2 specific groups
+# flashcomm specific groups
 _FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
 _FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
+_FC3_QUANT_X: Optional[GroupCoordinator] = None
 
 # shared_weight across rank groups
 _SHARED_WEIGHT: Optional[GroupCoordinator] = None

@@ -241,6 +242,15 @@ def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
         assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
         _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
 
+    if get_ascend_config().multistream_overlap_gate:
+        global _FC3_QUANT_X
+        group_ranks = all_ranks.unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+        _FC3_QUANT_X = init_model_parallel_group(group_ranks,
+                                                 get_world_group().local_rank,
+                                                 backend,
+                                                 group_name="fc3_quant_x")
+
 
 def model_parallel_initialized():
     return (_MC2 is not None)

@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
     return _P_TP
 
 
+def get_fc3_quant_x_group() -> GroupCoordinator:
+    assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
+    return _FC3_QUANT_X
+
+
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:

@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
     if _SHARED_WEIGHT:
         _SHARED_WEIGHT.destroy()
     _SHARED_WEIGHT = None
+
+    global _FC3_QUANT_X
+    if _FC3_QUANT_X:
+        _FC3_QUANT_X.destroy()
+    _FC3_QUANT_X = None
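
The new group follows the same construction pattern as the existing ones: the 2-D `all_ranks` tensor is unbound into per-group rank lists before being handed to `init_model_parallel_group`. A minimal sketch of just that rank-list step, assuming `all_ranks` is laid out as (num_groups, group_size) with 8 ranks in 2 groups; no HCCL or process groups are touched.

```python
import torch

# Hedged sketch of the rank-list construction used for the fc3_quant_x group.
# Assume 8 ranks arranged as 2 groups of 4, mirroring how `all_ranks` is laid
# out as (num_groups, group_size) in parallel_state.py.
all_ranks = torch.arange(8).reshape(2, 4)

group_ranks = all_ranks.unbind(0)             # tuple of 1-D tensors, one per group
group_ranks = [x.tolist() for x in group_ranks]

print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
# In the real code these lists are passed to init_model_parallel_group(...)
# together with the local rank and backend to create _FC3_QUANT_X.
```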

vllm_ascend/distributed/utils.py

Lines changed: 32 additions & 1 deletion

@@ -2,8 +2,11 @@
 
 import torch
 import torch.distributed as dist
+from vllm.forward_context import get_forward_context
 
-from vllm_ascend.distributed.parallel_state import get_p_tp_group
+from vllm_ascend.distributed.parallel_state import (get_dp_group,
+                                                    get_fc3_quant_x_group,
+                                                    get_p_tp_group)
 
 
 def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,

@@ -59,3 +62,31 @@ def get_transfer_timeout_value():
         '7'))  # type: ignore
     return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
                3000)
+
+
+def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        return x
+    x = get_fc3_quant_x_group().all_gather(x, 0)
+    dp_metadata = forward_context.dp_metadata
+    if dp_metadata is None:
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = x[:-pad_size]
+    else:
+        # unpad
+        num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
+        result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
+                             device=x.device,
+                             dtype=x.dtype)
+        dp_size = get_dp_group().world_size
+        x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+        offset = 0
+        for idx in range(dp_size):
+            num_tokens_dp = num_tokens_across_dp_cpu[idx]
+            result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
+            offset += num_tokens_dp
+        x = result
+    return x
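
The data-parallel branch of `fc3_all_gather_and_maybe_unpad_impl` views the gathered tensor as (dp_size, padded_length, hidden) and copies only each rank's valid rows into a contiguous result. Below is a hedged, CPU-only sketch of that unpad loop with made-up token counts (3 and 5 real tokens, padded to 5 each); it mirrors the loop above without any distributed calls.

```python
import torch

# Hedged CPU-only sketch of the unpad loop: 2 DP ranks, padded to 5 tokens each,
# but only 3 and 5 tokens are real. Hidden size 4 is arbitrary.
dp_size, padded_length, hidden = 2, 5, 4
num_tokens_across_dp_cpu = torch.tensor([3, 5])

# Pretend this is the all-gathered, padded activation of shape
# (dp_size * padded_length, hidden).
x = torch.arange(dp_size * padded_length * hidden, dtype=torch.float32).view(
    dp_size * padded_length, hidden)

result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]), dtype=x.dtype)
x = x.view(dp_size, padded_length, *x.shape[1:])

offset = 0
for idx in range(dp_size):
    num_tokens_dp = num_tokens_across_dp_cpu[idx]
    # Keep only the valid rows of this DP rank, drop its padding.
    result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
    offset += num_tokens_dp

print(result.shape)  # torch.Size([8, 4]): 3 + 5 valid tokens, padding dropped
```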

vllm_ascend/flash_common3_context.py (new file)

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from vllm.model_executor.layers.linear import LinearBase
+
+
+@dataclass
+class FlashCommon3Context:
+    gate: Optional[LinearBase] = None
+    topk_weights: Optional[torch.Tensor] = None
+    topk_ids: Optional[torch.Tensor] = None
+    row_idx: Optional[torch.Tensor] = None
+    shared_experts: Optional[torch.nn.Module] = None
+    shared_out: Optional[torch.Tensor] = None
+
+
+_flash_common3_context: Optional[FlashCommon3Context] = None
+
+
+def get_flash_common3_context() -> Optional[FlashCommon3Context]:
+    return _flash_common3_context
+
+
+def set_flash_common3_context(
+    topk_weights: Optional[torch.Tensor] = None,
+    topk_ids: Optional[torch.Tensor] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
+    shared_out: Optional[torch.Tensor] = None,
+):
+    global _flash_common3_context
+    if _flash_common3_context is None:
+        _flash_common3_context = FlashCommon3Context()
+
+    if topk_weights is not None:
+        _flash_common3_context.topk_weights = topk_weights
+    if topk_ids is not None:
+        _flash_common3_context.topk_ids = topk_ids
+    if shared_experts is not None:
+        _flash_common3_context.shared_experts = shared_experts
+    if shared_out is not None:
+        _flash_common3_context.shared_out = shared_out
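
The module acts as a small process-global scratchpad: producers call `set_flash_common3_context` to stash routing outputs or the shared-expert module, and consumers read them back with `get_flash_common3_context`. A short usage sketch with dummy tensors follows, assuming `vllm_ascend` is importable; the real producers and consumers are the layers in `fused_moe.py` below.

```python
import torch
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                               set_flash_common3_context)

# Hedged sketch of the set/get round trip with dummy routing tensors.
topk_weights = torch.rand(16, 8)           # (num_tokens, top_k), made-up shapes
topk_ids = torch.randint(0, 64, (16, 8))

set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)

ctx = get_flash_common3_context()
assert ctx is not None
# Only the fields that were passed are updated; the rest stay None until set.
print(ctx.topk_ids.shape, ctx.shared_out)  # torch.Size([16, 8]) None
```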

vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 87 additions & 21 deletions

@@ -37,9 +37,12 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
 from vllm_ascend.eplb.utils import moe_load_async_stream
+from vllm_ascend.flash_common3_context import (get_flash_common3_context,
+                                               set_flash_common3_context)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
-from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                       setup_moe_comm_method)
 from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
 from vllm_ascend.quantization.w4a8_dynamic import \
     AscendW4A8DynamicFusedMoEMethod

@@ -139,6 +142,7 @@ def apply(self,
 
 class AscendFusedMoE(FusedMoE):
     moe_counter = -1
+    gate_stream: Optional[torch.npu.Stream] = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

@@ -170,6 +174,10 @@ def __init__(self, *args, **kwargs):
         self.expert_map_path = ascend_config.expert_map_path
         self.global_redundant_expert_num = ascend_config.init_redundancy_expert
         self.global_num_experts = num_experts + self.global_redundant_expert_num
+        # flashcommon3 gate stream
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
+        if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
+            AscendFusedMoE.gate_stream = torch.npu.Stream()
         if self.custom_routing_function is None and self.e_score_correction_bias is not None:
             vllm_config = get_current_vllm_config()
             self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(

@@ -332,13 +340,58 @@ def forward_impl(self, hidden_states: torch.Tensor,
         enable_force_load_balance = forward_context.in_profile_run
 
         forward_context = get_forward_context()
+        if self.multistream_overlap_gate:
+            assert AscendFusedMoE.gate_stream is not None
+            fc3_context = get_flash_common3_context()
+            assert fc3_context is not None
+            AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
+            with npu_stream_switch(AscendFusedMoE.gate_stream,
+                                   enabled=self.multistream_overlap_gate):
+                # share_expert
+                assert fc3_context.shared_experts is not None
+                shared_out = fc3_context.shared_experts(hidden_states)
+                # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+                moe_comm_type = forward_context.moe_comm_type
+                if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                        and not shared_expert_dp_enabled():
+                    shared_out = tensor_model_parallel_all_reduce(shared_out)
+                set_flash_common3_context(shared_out=shared_out)
+
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=self.use_grouped_topk,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    custom_routing_function=self.custom_routing_function,
+                    scoring_func=self.scoring_func,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    e_score_correction_bias=self.e_score_correction_bias,
+                    global_num_experts=self.global_num_experts)
+
+                if isinstance(forward_context.moe_comm_method,
+                              AllGatherCommImpl):
+                    topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                        topk_weights, True, True)
+                    topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                        topk_ids, True, True)
+
+                set_flash_common3_context(topk_weights=topk_weights,
+                                          topk_ids=topk_ids)
+
         hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
             router_logits=router_logits,
             replace_allreduce=forward_context.sp_enabled,
             enable_shared_expert_dp=self.enable_shared_expert_dp,
             quant_type=self.quant_type)
 
+        # Make sure the default stream waits for the gate stream to finish.
+        if self.multistream_overlap_gate:
+            torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
+
         if isinstance(hidden_states, tuple):
             hidden_states, pertoken_scale = hidden_states
         else:

@@ -407,6 +460,7 @@ def __init__(
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."

@@ -443,30 +497,42 @@ def forward(
 
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
-        # Make sure the shared experts stream begins after hidden_states are ready.
-        if self.multistream_overlap_shared_expert:
-            shared_experts_calculation_stream().wait_stream(  # type: ignore
-                torch.npu.current_stream())
-        with npu_stream_switch(shared_experts_calculation_stream(),
-                               enabled=self.multistream_overlap_shared_expert):
-            # Use a separate stream to run shared experts.
-            # Note that currently we only support calculations in separate streams with aclgraph.
-            # Communication operations in another stream might cause unknown errors.
-            shared_out = self._shared_experts(hidden_states)
+        shared_out = None
+        if not self.multistream_overlap_gate:
+            # Make sure the shared experts stream begins after hidden_states are ready.
+            if self.multistream_overlap_shared_expert:
+                shared_experts_calculation_stream(
+                ).wait_stream(  # type: ignore
+                    torch.npu.current_stream())
+            with npu_stream_switch(
+                    shared_experts_calculation_stream(),
+                    enabled=self.multistream_overlap_shared_expert):
+                # Use a separate stream to run shared experts.
+                shared_out = self._shared_experts(hidden_states)
+        else:
+            set_flash_common3_context(shared_experts=self._shared_experts)
 
         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        # Make sure the default stream waits for the shared experts stream to finish.
-        if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(
-                shared_experts_calculation_stream())
-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
-                and not shared_expert_dp_enabled():
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
+
+        if not self.multistream_overlap_gate:
+            # Make sure the default stream waits for the shared experts stream to finish.
+            if self.multistream_overlap_shared_expert:
+                torch.npu.current_stream().wait_stream(
+                    shared_experts_calculation_stream())
+
+            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+            forward_context = get_forward_context()
+            moe_comm_type = forward_context.moe_comm_type
+            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                    and not shared_expert_dp_enabled():
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
+        else:
+            fc3_context = get_flash_common3_context()
+            assert fc3_context is not None
+            shared_out = fc3_context.shared_out
+
         return shared_out, fused_output
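
The overlap itself is the standard two-stream handshake: the gate stream first waits on the default stream so its inputs are ready, runs the shared-expert and routing work, and the default stream later waits on the gate stream before consuming `topk_weights`/`topk_ids`. Below is a hedged, framework-agnostic sketch of that ordering, using `torch.cuda.Stream` as a stand-in for `torch.npu.Stream` and dummy matmuls in place of the shared experts and `select_experts`.

```python
import torch

# Hedged sketch of the wait_stream handshake behind multistream_overlap_gate.
# torch.cuda.Stream stands in for torch.npu.Stream; requires a CUDA device.
device = torch.device("cuda")
gate_stream = torch.cuda.Stream()

hidden_states = torch.randn(1024, 4096, device=device)
w_gate = torch.randn(4096, 64, device=device)     # dummy "gate" weight
w_other = torch.randn(4096, 4096, device=device)  # dummy "prepare" work

# 1. The side stream must not start before hidden_states is produced.
gate_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(gate_stream):
    router_logits = hidden_states @ w_gate        # overlapped "gate" work

# 2. The default stream keeps doing the quant/all-gather style work in parallel.
prepared = hidden_states @ w_other

# 3. The default stream waits before it consumes the side stream's outputs.
torch.cuda.current_stream().wait_stream(gate_stream)
topk = router_logits.topk(8, dim=-1)
```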

vllm_ascend/ops/fused_moe/prepare_finalize.py

Lines changed: 29 additions & 5 deletions

@@ -29,7 +29,10 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
+from vllm_ascend.utils import (enable_sp, npu_stream_switch,
+                               prefill_context_parallel_enable)
 
 
 class QuantType(Enum):

@@ -49,9 +52,14 @@ class PrepareAndFinalize(ABC):
         moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
         sizes, ranks, and communication settings.
     """
+    quant_stream: Optional[torch.npu.Stream] = None
 
     def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
+        ascend_config = get_ascend_config()
+        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
+        if self.multistream_overlap_gate and PrepareAndFinalize.quant_stream is None:
+            PrepareAndFinalize.quant_stream = torch.npu.Stream()
 
     @abstractmethod
     def prepare(

@@ -335,12 +343,28 @@ def _prepare_with_ep_group(
         if quant_type == QuantType.W8A8:
             hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
                 hidden_states)
+
+        if self.multistream_overlap_gate:
+            assert PrepareAndFinalize.quant_stream is not None
+            PrepareAndFinalize.quant_stream.wait_stream(
+                torch.npu.current_stream())
+            with npu_stream_switch(PrepareAndFinalize.quant_stream,
+                                   enabled=self.multistream_overlap_gate):
+                hidden_states = fc3_all_gather_and_maybe_unpad_impl(
+                    hidden_states)
+        else:
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states, True, True)
+        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            router_logits, True, True)
+
+        if pertoken_scale is not None:
             pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 pertoken_scale, True, True)
-        hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            hidden_states, True, True)
-        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            router_logits, True, True)
+
+        if self.multistream_overlap_gate:
+            torch.npu.current_stream().wait_stream(
+                PrepareAndFinalize.quant_stream)
 
         if pertoken_scale is not None:
             return (hidden_states, pertoken_scale), router_logits, None, None
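
The prepare path reuses the same handshake through the `npu_stream_switch` helper, which only switches streams when the flag is on. Its real implementation lives in `vllm_ascend.utils` and is not part of this diff; the sketch below is only a guess at what such a conditional switch could look like, again with `torch.cuda` standing in for `torch.npu`, to illustrate why the call sites pass `enabled=...`.

```python
import contextlib
import torch

# Hedged guess at a conditional stream-switch helper (the real
# vllm_ascend.utils.npu_stream_switch is not shown in this diff).
# torch.cuda stands in for torch.npu.
@contextlib.contextmanager
def stream_switch(stream: torch.cuda.Stream, enabled: bool = True):
    if not enabled:
        # Flag off: run on the current stream, no extra synchronization.
        yield
        return
    with torch.cuda.stream(stream):
        yield

# Usage mirrors the call sites above: the body runs on `quant_stream`
# only when the overlap feature is enabled.
quant_stream = torch.cuda.Stream()
x = torch.randn(8, 8, device="cuda")
with stream_switch(quant_stream, enabled=True):
    y = x * 2
```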
