Commit d08401d

[Main][Bugfix] Avoid using the fusion operator in the MoE model (#3834)
### What this PR does / why we need it?

The fused MatmulReduceScatter operator suffers a performance degradation in small-shape scenarios, so this change decides whether to use it based on the size of the shape: the fusion is disabled for MoE models, whose workloads typically fall into the small-shape regime, and kept for dense models.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@releases/v0.11.1

---------

Signed-off-by: ZYang6263 <[email protected]>
1 parent 90ae114 commit d08401d

File tree

2 files changed: +13 −6 lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 3 additions & 0 deletions
@@ -113,13 +113,16 @@ def set_ascend_forward_context(
     # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
     # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
+    mmrs_fusion = True
     if is_moe_model(vllm_config):
         sp_enabled = enable_sp(vllm_config) and \
             tp_world_size > 1 and num_tokens is not None
+        mmrs_fusion = False
     else:
         sp_enabled = enable_sp(vllm_config) and \
             tp_world_size > 1 and \
             num_tokens is not None and num_tokens > 1000
+    forward_context.mmrs_fusion = mmrs_fusion
 
     if sp_enabled:
         pad_size = (tp_world_size -
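
The net effect of this hunk is that the forward context now carries an mmrs_fusion flag: True for dense models, False for MoE models. Below is a minimal sketch of the decision, assuming simplified standalone helpers (is_moe_model and enable_sp mirror the repository's own utilities; the wrapper function itself is illustrative, not the actual code path):

```python
def decide_mmrs_fusion(vllm_config, tp_world_size, num_tokens):
    """Illustrative sketch of the gating added in this commit."""
    if is_moe_model(vllm_config):
        # MoE layers tend to feed the TP linear small shapes, where the
        # fused MatmulReduceScatter kernel loses to matmul + reduce-scatter.
        sp_enabled = (enable_sp(vllm_config) and tp_world_size > 1
                      and num_tokens is not None)
        return sp_enabled, False
    # Dense models keep the fusion; SP itself only activates above the
    # empirical concurrency threshold of 1000 tokens.
    sp_enabled = (enable_sp(vllm_config) and tp_world_size > 1
                  and num_tokens is not None and num_tokens > 1000)
    return sp_enabled, True
```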

vllm_ascend/ops/linear_op.py

Lines changed: 10 additions & 6 deletions
@@ -382,8 +382,10 @@ def matmul_and_reduce(self, input_parallel: torch.Tensor,
         try:
             forward_context = get_forward_context()
             sp_enabled = forward_context.sp_enabled
+            mmrs_fusion = forward_context.mmrs_fusion
         except AssertionError:
             sp_enabled = False
+            mmrs_fusion = False
 
         x = input_parallel
 
@@ -409,8 +411,9 @@ def matmul_and_reduce(self, input_parallel: torch.Tensor,
                               quant_per_tensor)
 
         # For unquant
-        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
-                      ) and torch.version.cann.startswith("8.3"):
+        if mmrs_fusion and isinstance(
+                self.layer.quant_method, UnquantizedLinearMethod
+        ) and torch.version.cann.startswith("8.3"):
             output = torch_npu.npu_mm_reduce_scatter_base(
                 x,
                 self.layer.weight.t(),
@@ -423,10 +426,11 @@ def matmul_and_reduce(self, input_parallel: torch.Tensor,
             if bias_ is not None:
                 output.add_(bias_)
         # For w8a8 quant
-        elif (isinstance(self.layer.quant_method, AscendLinearMethod)
-              and isinstance(self.layer.quant_method.quant_method,
-                             AscendW8A8LinearMethod)
-              ) and torch.version.cann.startswith("8.3"):
+        elif mmrs_fusion and (
+                isinstance(self.layer.quant_method, AscendLinearMethod)
+                and isinstance(self.layer.quant_method.quant_method,
+                               AscendW8A8LinearMethod)
+        ) and torch.version.cann.startswith("8.3"):
             if x.dtype != torch.int8:
                 x_quant = quant_per_tensor(
                     x, self.layer.aclnn_input_scale_reciprocal,
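
Downstream, matmul_and_reduce now takes the fused npu_mm_reduce_scatter_base path only when mmrs_fusion is set (and the CANN 8.3 check passes); otherwise it falls through to an ordinary matmul followed by a reduce-scatter. A hedged sketch of that unfused alternative, assuming a tensor-parallel process group and plain torch.distributed collectives (the function and its arguments are illustrative, not the repository's exact fallback code):

```python
import torch
import torch.distributed as dist

def matmul_reduce_scatter_unfused(x: torch.Tensor, weight: torch.Tensor,
                                  tp_group) -> torch.Tensor:
    # Each TP rank computes a partial matmul over its weight shard...
    partial = torch.matmul(x, weight.t())
    # ...then reduce_scatter_tensor sums the partials across ranks and
    # leaves each rank with 1/world_size of the output rows.
    world_size = dist.get_world_size(tp_group)
    out = torch.empty((partial.shape[0] // world_size, *partial.shape[1:]),
                      dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(out, partial, group=tp_group)
    return out
```

For small token counts the two-step version wins because the fused kernel's overhead dominates; for large shapes the fused operator amortizes that cost, which is the trade-off the mmrs_fusion flag encodes.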
