Skip to content

Commit 7555a6c

Browse files
Merge commit 'refs/pull/4751/head' of https://github.com/vllm-project/vllm-ascend into eplb_mix
2 parents 60c5bb3 + 85204de commit 7555a6c

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

vllm_ascend/attention/sfa_v1.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -484,14 +484,15 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
484484
self._process_weights_for_fused_mlapo(act_dtype)
485485

486486
def _v_up_proj(self, x):
487-
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
488-
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
489-
x = torch_npu.npu_transpose_batchmatmul(x,
490-
self.W_UV,
491-
perm_x1=[1, 0, 2],
492-
perm_x2=[0, 1, 2],
493-
perm_y=[1, 0, 2])
494-
x = x.reshape(-1, self.local_num_heads * self.v_head_dim)
487+
if x.dtype in [torch.float16, torch.bfloat16] \
488+
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
489+
x = x.view(-1, self.num_heads, self.kv_lora_rank)
490+
b, _, _ = x.shape
491+
res = torch.empty((b, self.num_heads, self.v_head_dim),
492+
dtype=x.dtype,
493+
device=x.device)
494+
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
495+
x = res.reshape(-1, self.num_heads * self.v_head_dim)
495496
else:
496497
# Convert from (B, N, L) to (N, B, L)
497498
x = x.view(-1, self.local_num_heads,

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from vllm.v1.utils import CpuGpuBuffer
2828
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
2929

30-
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
30+
from vllm_ascend.ascend_forward_context import (MoECommType,
31+
set_ascend_forward_context)
3132
from vllm_ascend.attention.attention_v1 import AscendAttentionState
3233
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
3334
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -237,6 +238,9 @@ def dummy_run(self,
237238
) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
238239

239240
moe_comm_type = self.runner._select_moe_comm_method(num_tokens)
241+
# TODO: remove this after moe_comm_type selection logic is finalized
242+
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
243+
== MoECommType.FUSED_ALLTOALL else moe_comm_type)
240244

241245
if skip_attn:
242246
attn_metadata = None

vllm_ascend/worker/model_runner_v1.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
has_kv_transfer_group)
5353
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
5454
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
55-
get_pcp_group, get_pp_group,
56-
get_tp_group,
55+
get_ep_group, get_pcp_group,
56+
get_pp_group, get_tp_group,
5757
is_global_first_rank)
5858
from vllm.forward_context import get_forward_context
5959
from vllm.logger import logger
@@ -2349,10 +2349,11 @@ def _select_moe_comm_method(self,
23492349
moe_comm_type = MoECommType.ALLGATHER
23502350

23512351
elif soc_version in {AscendDeviceType._910_93}:
2352-
moe_comm_type = (MoECommType.MC2
2353-
if num_tokens <= self.mc2_tokens_capacity else
2354-
MoECommType.FUSED_ALLTOALL if quant_type
2355-
== "w8a8_dynamic" else MoECommType.ALLTOALL)
2352+
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
2353+
moe_comm_type = (
2354+
MoECommType.MC2 if num_tokens <= self.mc2_tokens_capacity else
2355+
MoECommType.FUSED_ALLTOALL if quant_type == "w8a8_dynamic"
2356+
and get_ep_group().world_size <= 16 else MoECommType.ALLTOALL)
23562357
else:
23572358
raise ValueError(f"Unsupported soc_version: {soc_version}")
23582359

0 commit comments

Comments
 (0)