
Conversation


@zhanzy178 zhanzy178 commented Dec 1, 2025

What this PR does / why we need it?

  • Vectorize the loops (without changing the output) in several rejection-sampler functions: expand_pytorch, sample_recovered_tokens_pytorch, rejection_random_sample_pytorch, and sample_recovered_tokens.
  • Remove synchronously launching torch_npu operators from them to accelerate sampling + MTP post-processing (a simplified sketch of the pattern follows below).
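
As a rough illustration of the pattern (a simplified, hypothetical sketch rather than the actual PR code; cu_num_tokens and req_values are made-up names), a per-request Python loop that forces host-device synchronization via .item() can be replaced by a single on-device op that launches asynchronously:

```python
import torch

# Loop version: every .item() call blocks the host until the device catches up.
def expand_loop(cu_num_tokens: torch.Tensor, req_values: torch.Tensor) -> torch.Tensor:
    # cu_num_tokens: [num_reqs + 1] cumulative token counts, starting at 0
    out = req_values.new_empty(int(cu_num_tokens[-1].item()))
    for i in range(req_values.shape[0]):
        start = int(cu_num_tokens[i].item())
        end = int(cu_num_tokens[i + 1].item())
        out[start:end] = req_values[i]
    return out

# Vectorized version: repeat_interleave keeps everything on-device, so the op
# is queued without blocking the Python thread.
def expand_vectorized(cu_num_tokens: torch.Tensor, req_values: torch.Tensor) -> torch.Tensor:
    lengths = cu_num_tokens[1:] - cu_num_tokens[:-1]
    return torch.repeat_interleave(req_values, lengths)
```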

Does this PR introduce any user-facing change?

  • No

How was this patch tested?

  • We tested this change with the following serve & bench commands:
===== serve =====
vllm serve $LOCAL_CKPT_DIR \
        --host 0.0.0.0 \
        --port 8000 \
        --data-parallel-size 4 \
        --data-parallel-size-local 2 \
        --data-parallel-address $MASTER_NODE_IP \
        --data-parallel-start-rank $((2*VC_TASK_INDEX)) \
        --data-parallel-rpc-port 13387 \
        --tensor-parallel-size 8 \
        --seed 1024 \
        --enable-expert-parallel \
        --served-model-name $NAME \
        --max-model-len 4096 \
        --max-num-seqs 16 \
        --trust-remote-code \
        --gpu-memory-utilization 0.90 \
        $headless \
        --speculative_config '{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
        --additional-config '{"ascend_scheduler_config":{"enabled":false, "enable_chunked_prefill":true, "chunked_prefill_enabled":true}}' 

===== bench =====
vllm bench serve --model $LOCAL_CKPT_DIR  --served-model-name DeepseekV3ForCausalLM \
--dataset-name spec_bench --spec-bench-output-len 2048 \
--dataset-path question.jsonl \
--top-p 1.0 --temperature 0.8 \
--ignore-eos \
--num-prompts 64  --trust-remote-code --base-url "http://0.0.0.0:8000" --request-rate 64
  • In this case, our rejection-sampler optimization reduces TPOT from 84.94 ms to 64.61 ms, about a 24% improvement.

before: [benchmark screenshot]

after: [benchmark screenshot]


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors several functions in the rejection sampler to replace Python loops with vectorized PyTorch operations. The goal is to improve performance on NPUs by enabling asynchronous execution, and the changes in rejection_random_sample_pytorch, expand_pytorch, and sample_recovered_tokens_pytorch are substantial and align with this objective. My review identified a critical bug in expand_pytorch involving reading from uninitialized memory, as well as a type mismatch issue. I also found a potential robustness issue in rejection_random_sample_pytorch related to division by zero. I have provided code suggestions to address these problems.

Comment on lines 534 to 540
replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()

output_slice = start_idx + offset[mask]
output_ptr[output_slice] = src_val
token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)

needs_update = in_range.any(dim=1)

output_ptr[:] = torch.where(needs_update, token_values, output_ptr)

critical

This block of code has two issues:

  1. Critical Bug: The line output_ptr[:] = torch.where(needs_update, token_values, output_ptr) reads from output_ptr when needs_update is false. Since output_ptr is an uninitialized tensor (created with new_empty), this constitutes reading from garbage memory, which is undefined behavior.
  2. Type Mismatch: replaced_input is cast to float using .float(), causing token_values to become a float tensor. However, output_ptr is expected to be an integer tensor, which could lead to incorrect results or errors.

The logic can be simplified to fix both issues by removing the .float() cast, the needs_update mask, and the torch.where call. The einsum operation will correctly produce zeros for tokens that are not part of any request, so a direct assignment to output_ptr is sufficient, safer, and more efficient.

Suggested change
- replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
- output_slice = start_idx + offset[mask]
- output_ptr[output_slice] = src_val
- token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
- needs_update = in_range.any(dim=1)
- output_ptr[:] = torch.where(needs_update, token_values, output_ptr)
+ replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr)
+ token_values = torch.einsum("tb,b->t", in_range.to(replaced_input.dtype), replaced_input)
+ output_ptr[:] = token_values.to(output_ptr.dtype)
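
To illustrate the first issue with a minimal standalone sketch (made-up shapes and values, not the PR's tensors): torch.where evaluates both branches element-wise, so every position where the condition is False copies whatever bytes happen to sit in the uninitialized buffer.

```python
import torch

out = torch.empty(4, dtype=torch.int64)            # contents are undefined
mask = torch.tensor([True, False, True, False])
vals = torch.tensor([10, 20, 30, 40])
# Positions 1 and 3 keep the uninitialized data from `out`.
out[:] = torch.where(mask, vals, out)
print(out)                                         # e.g. tensor([10, <garbage>, 30, <garbage>])
```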

Comment on lines 449 to 451
acceptance_condition = (draft_token_probs > zero_threshold) & (
target_token_probs / draft_token_probs >= uniform_token_probs
)

high

The division target_token_probs / draft_token_probs can produce inf or NaN when draft_token_probs contains zeros. The element-wise AND with draft_token_probs > zero_threshold masks those positions out of the final result, but tensor operations do not short-circuit: both operands are fully evaluated, so the inf/NaN values are still materialized, and the pattern is fragile under different compiler optimizations or hardware. It is safer to avoid the division entirely by rewriting the comparison with a multiplication. This improves robustness and clarity, similar to how division by zero is handled in sample_recovered_tokens_pytorch.

Suggested change
- acceptance_condition = (draft_token_probs > zero_threshold) & (
-     target_token_probs / draft_token_probs >= uniform_token_probs
- )
+ acceptance_condition = (draft_token_probs > zero_threshold) & (
+     target_token_probs >= uniform_token_probs * draft_token_probs
+ )
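
A small illustration (hypothetical values) of why the multiplication form is preferable when a draft probability is exactly zero:

```python
import torch

draft = torch.tensor([0.0, 0.5])
target = torch.tensor([0.3, 0.4])
u = torch.tensor([0.2, 0.6])

# Division form: target / draft materializes inf (or nan) before the mask is applied.
div_form = (draft > 0) & (target / draft >= u)
# Multiplication form: no division, so no inf/nan is ever produced.
mul_form = (draft > 0) & (target >= u * draft)

# Both evaluate to tensor([False, True]) here, but only the second avoids inf/nan.
assert torch.equal(div_form, mul_form)
```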

@zhanzy178 zhanzy178 changed the title Refactor some rejectsampler functions to make npu op launch async. Optimize some rejectsampler functions to make npu op launch async. Dec 1, 2025

github-actions bot commented Dec 1, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@zhanzy178 zhanzy178 changed the title Optimize some rejectsampler functions to make npu op launch async. Optimize some rejectsampler functions to make npu op launch non-blocking Dec 2, 2025
@wangxiyuan wangxiyuan added ready read for review ready-for-test start test by label for PR labels Dec 2, 2025
Signed-off-by: ZongYuan Zhan <[email protected]>
target_token_probs / draft_token_probs >= uniform_token_probs
)

first_rejection = (~acceptance_condition) & valid_mask


There are 400–500 lines of code here without comments. It is recommended to add comments for the core functionality to improve code readability.
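
For example, a commented, self-contained sketch of the vectorized acceptance / first-rejection step (hypothetical helper and argument names, not the PR's exact code) might look like:

```python
import torch

def accepted_token_counts(target_probs: torch.Tensor,
                          draft_probs: torch.Tensor,
                          uniform: torch.Tensor,
                          valid_mask: torch.Tensor) -> torch.Tensor:
    """All inputs have shape [num_reqs, num_spec_tokens]. Returns, per request,
    how many leading draft tokens are accepted (the first-rejection index)."""
    # Accept when target_p >= u * draft_p; the multiplication form avoids
    # dividing by a zero draft probability.
    accepted = (draft_probs > 0) & (target_probs >= uniform * draft_probs)
    # Padding positions must never count as rejections.
    accepted_or_pad = accepted | ~valid_mask
    # cumprod zeroes everything from the first rejection onward, so the row sum
    # equals the number of leading accepted tokens.
    num_leading = torch.cumprod(accepted_or_pad.int(), dim=1).sum(dim=1)
    # Cap at the number of valid draft tokens so trailing padding is not counted.
    return torch.minimum(num_leading, valid_mask.sum(dim=1))
```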

