
Commit eb4c08f

[bugfix] fix mtp accept rate (#5093)
### What this PR does / why we need it?

1. `npu_model_runner` now reuses `gpu_model_runner`; this PR deletes some attrs that are already defined in `gpu_model_runner`.
2. Fix the MTP accept rate by disabling `in_profile_run`.
3. Remove redundant MoE comm method selection logic.
4. Reverts #5082, which broke CI in https://github.com/vllm-project/vllm-ascend/actions/runs/20266314048/job/58190426832?pr=5088

### Does this PR introduce _any_ user-facing change?

NO

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: zhenwenqi2024 <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
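For context, the "MTP accept rate" in the title is the speculative-decoding acceptance ratio: the fraction of draft tokens proposed by the MTP head that the target model accepts. The helper below is a purely illustrative sketch of that metric and is not part of this patch; the argument names merely echo the `num_accepted_tokens` / `num_draft_tokens` buffers visible in `model_runner_v1.py`, and the aggregation is an assumption, not vLLM's exact bookkeeping.

```python
from collections.abc import Sequence


def mtp_accept_rate(accepted_per_req: Sequence[int],
                    drafted_per_req: Sequence[int]) -> float:
    """Fraction of proposed MTP draft tokens that the target model accepted.

    Illustrative only; names echo the runner's num_accepted_tokens /
    num_draft_tokens buffers, but the real aggregation differs.
    """
    drafted = sum(drafted_per_req)
    return sum(accepted_per_req) / drafted if drafted else 0.0


# Two requests: 3 of 4 and 5 of 6 draft tokens accepted -> 0.8 overall.
print(mtp_accept_rate([3, 5], [4, 6]))
```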
1 parent 5b1da4e commit eb4c08f

File tree

5 files changed, +9 −35 lines changed

csrc/matmul_allreduce_add_rmsnorm/op_host/aclnn_matmul_allreduce_add_rmsnorm.cpp

Lines changed: 4 additions & 4 deletions
@@ -26,10 +26,6 @@ enum NnopbaseHcclServerType {
 };
 extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);

-#ifdef __cplusplus
-extern "C" {
-#endif
-
 extern aclnnStatus aclnnInnerMatmulAllreduceAddRmsnormGetWorkspaceSize(
     const aclTensor *x1,
     const aclTensor *x2,
@@ -52,6 +48,10 @@ extern aclnnStatus aclnnInnerMatmulAllreduceAddRmsnorm(
     aclOpExecutor *executor,
     aclrtStream stream);

+#ifdef __cplusplus
+extern "C" {
+#endif
+
 aclnnStatus aclnnMatmulAllreduceAddRmsnormGetWorkspaceSize(
     const aclTensor *x1,
     const aclTensor *x2,

vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 3 deletions
@@ -64,7 +64,7 @@ def set_ascend_forward_context(
         get_moe_comm_method
     moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
     # TODO: remove this after moe_comm_type selection logic is finalized
-    if in_profile_run and is_mtp_model:
+    if is_mtp_model:
         moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
                          == MoECommType.FUSED_ALLTOALL else moe_comm_type)
     forward_context.moe_comm_type = moe_comm_type
@@ -298,8 +298,6 @@ def select_moe_comm_method(num_tokens: int,
                          if fused_all2all_enable else MoECommType.ALLTOALL)
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
-    moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
-                     == MoECommType.FUSED_ALLTOALL else moe_comm_type)
     # PanguProMoE only supports allgather
     if model_type == "PanguProMoE":
         moe_comm_type = MoECommType.ALLGATHER
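Per the commit message, the accept-rate fix comes from dropping `in_profile_run` from this condition: the FUSED_ALLTOALL-to-ALLTOALL downgrade for MTP models now applies on every forward pass, not only during profile runs, so dummy/profile passes and real decode steps end up with the same MoE comm method. Below is a minimal sketch of the post-patch override, assuming only what the hunk above shows; the standalone enum and helper name are illustrative, not the real vllm-ascend definitions.

```python
from enum import Enum, auto


class MoECommType(Enum):
    # Only the members this hunk touches; the real enum in vllm-ascend has more.
    ALLGATHER = auto()
    ALLTOALL = auto()
    FUSED_ALLTOALL = auto()


def apply_mtp_override(moe_comm_type: MoECommType, is_mtp_model: bool) -> MoECommType:
    """Illustrative stand-in for the post-patch override in set_ascend_forward_context."""
    if is_mtp_model and moe_comm_type == MoECommType.FUSED_ALLTOALL:
        # MTP models fall back to plain ALLTOALL, regardless of profiling state.
        return MoECommType.ALLTOALL
    return moe_comm_type


# Before this patch the override also required in_profile_run=True, so real MTP
# decode steps could keep FUSED_ALLTOALL while profile/dummy runs used ALLTOALL.
assert apply_mtp_override(MoECommType.FUSED_ALLTOALL, is_mtp_model=True) is MoECommType.ALLTOALL
assert apply_mtp_override(MoECommType.ALLGATHER, is_mtp_model=True) is MoECommType.ALLGATHER
```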

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 0 additions & 1 deletion
@@ -145,7 +145,6 @@ def dummy_run(self,
                   dummy_compute_logits=lambda hidden_states: None):
         with set_ascend_forward_context(None,
                                         self.vllm_config,
-                                        in_profile_run=True,
                                         num_tokens=num_tokens):
             self.model(
                 input_ids=self.input_ids[:num_tokens],

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 0 additions & 1 deletion
@@ -293,7 +293,6 @@ def dummy_run(self,
                 self.vllm_config,
                 num_tokens=num_tokens,
                 with_prefill=with_prefill,
-                in_profile_run=True,
                 num_tokens_across_dp=num_tokens_across_dp,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 26 deletions
@@ -244,8 +244,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.need_accepted_tokens: bool = False

         self.is_multimodal_model = self.model_config.is_multimodal_model
-        self.is_pooling_model = self.model_config.pooler_config is not None
-        self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
         self.block_size = vllm_config.cache_config.block_size
         # Set up Attention
         self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
@@ -338,24 +336,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         ascend_config = get_ascend_config()
         self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader,
                                         self.eplb_process, self.process)
-
-        self.use_async_scheduling = self.scheduler_config.async_scheduling
-        self.async_output_copy_stream = torch.npu.Stream() if \
-            self.use_async_scheduling else None
-        self.num_spec_tokens = 0
-        if self.speculative_config:
-            self.num_spec_tokens = self.speculative_config.num_speculative_tokens  # noqa
-        self.valid_sampled_token_count_event: torch.npu.Event | None = None
-        self.valid_sampled_token_count_copy_stream: torch.npu.Stream | None = None
-        if self.use_async_scheduling and self.num_spec_tokens:
-            self.valid_sampled_token_count_event = torch.npu.Event()
-            self.valid_sampled_token_count_copy_stream = torch.npu.Stream()
-            self.valid_sampled_token_count_cpu = torch.empty(
-                self.max_num_reqs,
-                dtype=torch.int64,
-                device="cpu",
-                pin_memory=self.pin_memory,
-            )
         # Input Batch
         # NOTE(Chen): Ideally, we should initialize the input batch inside
         # `initialize_kv_cache` based on the kv cache config. However, as in
@@ -386,23 +366,20 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             cp_kv_cache_interleave_size=self.parallel_config.
             cp_kv_cache_interleave_size,
         )
-        self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
-                                                     dtype=torch.int64)
         self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                   dtype=torch.int32)
+        # here we use int32
         self.sampled_token_ids_pinned_cpu = torch.empty(
             (self.max_num_reqs, 1),
             dtype=torch.int32,
             device="cpu",
             pin_memory=self.pin_memory,
         )
+        # for cleancode , actually the three attrs is defined in gpu_model_runner
+        self.execute_model_state: ExecuteModelState | None = None
         # None in the first PP rank. The rest are set after load_model.
-        # the attr below is in gpu_modelrunner, but occurs lint so add them here
         self.intermediate_tensors: IntermediateTensors | None = None
-        self.execute_model_state: ExecuteModelState | None = None
         self.reorder_batch_threshold: int | None = None
-        self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
-                                                 dtype=torch.int32)

     def _init_device_properties(self) -> None:
         self.num_sms = None
@@ -3395,6 +3372,7 @@ def __init__(self, *args, **kwargs) -> None:

         try:
             # replace cuda APIs with xpu APIs, this should work by default
+            torch.Event = torch.npu.Event
             torch.cuda.Event = torch.npu.Event
             torch.cuda.Stream = torch.npu.Stream
             torch.cuda.default_stream = torch.npu.default_stream
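The last hunk extends the existing CUDA-to-NPU aliasing so that `torch.Event`, and not just `torch.cuda.Event`, resolves to `torch.npu.Event`. Below is a standalone sketch of that pattern, assuming a host where `torch_npu` is importable; the guard mirrors the `try:` in the hunk, and the surrounding structure is illustrative rather than the exact vllm-ascend code.

```python
import torch

try:
    import torch_npu  # noqa: F401  # registers the torch.npu namespace on Ascend hosts

    # Alias CUDA-flavoured stream/event APIs to their NPU counterparts so code
    # written against torch.cuda keeps working on NPU devices.
    torch.Event = torch.npu.Event
    torch.cuda.Event = torch.npu.Event
    torch.cuda.Stream = torch.npu.Stream
    torch.cuda.default_stream = torch.npu.default_stream
except (ImportError, AttributeError):
    # No NPU runtime available (e.g. on a CPU/CUDA box); leave torch untouched.
    pass
```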
