Optimize some rejectsampler functions to make npu op launch non-blocking #4587
base: main
Conversation
Signed-off-by: ZongYuan Zhan <[email protected]>
Code Review
This pull request refactors several functions in the rejection sampler to replace Python loops with vectorized PyTorch operations. The goal is to improve performance on NPUs by enabling asynchronous execution, and the changes in rejection_random_sample_pytorch, expand_pytorch, and sample_recovered_tokens_pytorch are substantial and align with this objective. My review identified a critical bug in expand_pytorch involving reading from uninitialized memory, as well as a type mismatch issue. I also found a potential robustness issue in rejection_random_sample_pytorch related to division by zero. I have provided code suggestions to address these problems.
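To make the intent of the refactor concrete before the specific diff hunks, here is a minimal, hypothetical sketch of the loop-to-vectorized pattern described above; the function names, shapes, and acceptance formula are illustrative assumptions, not the PR's actual code.

```python
import torch

def accept_loop(draft_probs, target_probs, uniform):
    # Python loop: each iteration launches tiny ops and the `if` forces a
    # host/device sync, so the accelerator cannot run ahead of the host.
    accepted = torch.zeros_like(uniform, dtype=torch.bool)
    for i in range(uniform.shape[0]):
        if draft_probs[i] > 0 and target_probs[i] / draft_probs[i] >= uniform[i]:
            accepted[i] = True
    return accepted

def accept_vectorized(draft_probs, target_probs, uniform):
    # One batched expression: the ops are queued and the host returns
    # immediately, keeping the op launch non-blocking.
    return (draft_probs > 0) & (target_probs >= uniform * draft_probs)

# Quick CPU check that both forms agree.
d = torch.tensor([0.0, 0.5, 0.9])
t = torch.tensor([0.3, 0.4, 0.9])
u = torch.tensor([0.7, 0.6, 0.5])
assert torch.equal(accept_loop(d, t, u), accept_vectorized(d, t, u))
```

On an asynchronous device, the loop form serializes host and device at every iteration, which is the blocking behavior the PR title refers to; the batched form launches a fixed, small number of kernels regardless of batch size.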
    replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
    output_slice = start_idx + offset[mask]
    output_ptr[output_slice] = src_val
    token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
    needs_update = in_range.any(dim=1)
    output_ptr[:] = torch.where(needs_update, token_values, output_ptr)
This block of code has two issues:
- Critical Bug: The line output_ptr[:] = torch.where(needs_update, token_values, output_ptr) reads from output_ptr when needs_update is false. Since output_ptr is an uninitialized tensor (created with new_empty), this constitutes reading from garbage memory, which is undefined behavior.
- Type Mismatch: replaced_input is cast to float using .float(), causing token_values to become a float tensor. However, output_ptr is expected to be an integer tensor, which could lead to incorrect results or errors.
The logic can be simplified to fix both issues by removing the .float() cast, the needs_update mask, and the torch.where call. The einsum operation will correctly produce zeros for tokens that are not part of any request, so a direct assignment to output_ptr is sufficient, safer, and more efficient.
Current:
    replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
    output_slice = start_idx + offset[mask]
    output_ptr[output_slice] = src_val
    token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
    needs_update = in_range.any(dim=1)
    output_ptr[:] = torch.where(needs_update, token_values, output_ptr)

Suggested change:
    replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr)
    token_values = torch.einsum("tb,b->t", in_range.to(replaced_input.dtype), replaced_input)
    output_ptr[:] = token_values.to(output_ptr.dtype)
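For readers who want to see how the suggested lines could sit inside a complete helper, here is a hedged sketch of an expand-style function; cu_num_tokens, the in_range construction, and the argument names are assumptions based on the diff context, not the repository's actual signature.

```python
import torch

def expand_tokens(output_ptr, input_ptr, cu_num_tokens, replace_from, replace_to):
    # Assumed semantics: request b owns output positions
    # [cu_num_tokens[b-1], cu_num_tokens[b]) and receives input_ptr[b],
    # with replace_from values swapped for replace_to.
    num_tokens = output_ptr.shape[0]
    token_idx = torch.arange(num_tokens, device=output_ptr.device)
    starts = torch.cat([cu_num_tokens.new_zeros(1), cu_num_tokens[:-1]])
    ends = cu_num_tokens
    # in_range[t, b] is True iff token t belongs to request b.
    in_range = (token_idx[:, None] >= starts[None, :]) & (token_idx[:, None] < ends[None, :])

    # Suggested form: no .float() cast and no torch.where over the
    # uninitialized output buffer.
    replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr)
    token_values = torch.einsum("tb,b->t", in_range.to(replaced_input.dtype), replaced_input)
    output_ptr[:] = token_values.to(output_ptr.dtype)
    return output_ptr

# Example: two requests with 2 and 3 tokens; the value 7 is rewritten to -1.
out = torch.empty(5, dtype=torch.long)
print(expand_tokens(out, torch.tensor([7, 42]), torch.tensor([2, 5]), 7, -1))
# tensor([-1, -1, 42, 42, 42])
```

Because each token belongs to exactly one request, the einsum sums a single nonzero term per row, so the direct assignment never touches uninitialized memory.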
    acceptance_condition = (draft_token_probs > zero_threshold) & (
        target_token_probs / draft_token_probs >= uniform_token_probs
    )
The division target_token_probs / draft_token_probs can produce inf or NaN if draft_token_probs contains zeros. The elementwise AND (&) with draft_token_probs > zero_threshold masks those entries out of the final result, but it does not short-circuit: the invalid values are still computed, which is error-prone and may behave differently under other compiler optimizations or hardware. It is safer to avoid the division entirely by rewriting the expression with a multiplication, similar to how division by zero is handled in sample_recovered_tokens_pytorch. This improves code robustness and clarity.
Current:
    acceptance_condition = (draft_token_probs > zero_threshold) & (
        target_token_probs / draft_token_probs >= uniform_token_probs
    )

Suggested change:
    acceptance_condition = (draft_token_probs > zero_threshold) & (
        target_token_probs >= uniform_token_probs * draft_token_probs
    )
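A tiny self-contained example (with made-up probabilities and an assumed zero_threshold) shows the difference: the division form materializes an inf that the mask then has to hide, while the multiplication form never produces an invalid intermediate value.

```python
import torch

draft_token_probs = torch.tensor([0.0, 0.5])    # first entry triggers the issue
target_token_probs = torch.tensor([0.3, 0.4])
uniform_token_probs = torch.tensor([0.7, 0.6])
zero_threshold = 1e-9  # assumed value, for illustration only

# Division form: 0.3 / 0.0 evaluates to inf before the mask is applied.
div_form = (draft_token_probs > zero_threshold) & (
    target_token_probs / draft_token_probs >= uniform_token_probs)

# Multiplication form: every intermediate value stays finite.
mul_form = (draft_token_probs > zero_threshold) & (
    target_token_probs >= uniform_token_probs * draft_token_probs)

print(div_form, mul_form)  # both print tensor([False,  True])
```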
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Signed-off-by: ZongYuan Zhan <[email protected]>
Signed-off-by: ZongYuan Zhan <[email protected]>
        target_token_probs / draft_token_probs >= uniform_token_probs
    )
    first_rejection = (~acceptance_condition) & valid_mask
This change adds 400-500 lines of code without comments. It is recommended to add comments for the core functionality to improve code readability.
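As an illustration of the kind of inline documentation being requested, here is a hedged sketch of how a first_rejection mask like the one in the snippet above is commonly reduced to a per-request "first rejected position"; the shapes and the follow-up reduction are assumptions, not the PR's actual code.

```python
import torch

# Assumed shapes: [num_requests, max_spec_len] boolean tensors.
acceptance_condition = torch.tensor([[True, False, True],
                                     [True, True,  True]])
valid_mask = torch.tensor([[True, True, True],
                           [True, True, False]])

# A draft token is a rejection candidate if it is valid but not accepted.
first_rejection = (~acceptance_condition) & valid_mask

# Reduce to the index of the first rejected token per request; requests with
# no rejection get the full length, i.e. all their draft tokens are accepted.
has_rejection = first_rejection.any(dim=1)
first_idx = torch.where(has_rejection,
                        first_rejection.int().argmax(dim=1),
                        torch.full_like(has_rejection.long(), first_rejection.shape[1]))
print(first_idx)  # tensor([1, 3])
```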
What this PR does / why we need it?
Optimize expand_pytorch, sample_recovered_tokens_pytorch, rejection_random_sample_pytorch, and sample_recovered_tokens so that NPU op launches are non-blocking.
Does this PR introduce any user-facing change?
How was this patch tested?
before
after