feat: implement high-performance Triton kernels for rejection sampling #4830
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces new Triton kernels for rejection sampling to improve performance, and the benchmarks provided show significant speedups. However, my review has identified several critical issues in the new Triton kernel implementations. The kernels currently use anti-patterns such as serializing work within a thread block by looping over elements and using data-dependent loop bounds. These patterns can lead to correctness issues and severely limit performance. Refactoring these kernels to use proper vectorization with masking is highly recommended to achieve optimal performance and ensure correctness. Additionally, there is duplicated code for calculating grid and block sizes and the use of magic numbers, which should be addressed to improve maintainability.
```python
for pos in tl.arange(0, BLOCK_SIZE):
    draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
    target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
    position = block_idx * BLOCK_SIZE + pos
    if draft_token_id1 == target_argmax1:
        bonus_renew_1(
            bonus_token_ids_ptr,
            position,
            output_token_ids_ptr,
        )
```
This loop iterates over BLOCK_SIZE and processes elements serially using tl.get_element. This is a Triton anti-pattern that prevents vectorization and causes thread divergence, which significantly degrades performance on a GPU. The logic can be fully vectorized using masks for much better efficiency. The bonus_renew_1 helper function would also become unnecessary.
```python
accepted_mask = (draft_token_id == target_argmax_id) & mask
# The bonus_renew_1 function can be inlined here for vectorization.
# Load bonus tokens only for accepted lanes.
bonus_token_id = tl.load(bonus_token_ids_ptr + offset, mask=accepted_mask)
# Store bonus tokens only for accepted lanes.
tl.store(output_token_ids_ptr + offset * 2 + 1, bonus_token_id, mask=accepted_mask)
```

```python
for i in tl.range(0, BLOCK_SIZE):
    num_tokens1 = tl.get_element(num_tokens, (i, ))
    start_idx1 = tl.get_element(start_idx, (i, ))
    src_val1 = tl.get_element(src_val, (i, ))
    offset1 = tl.arange(0, MAX_NUM_TOKENS)
    tl.store(output_ptr + start_idx1 + offset1,
             src_val1,
             mask=offset1 < num_tokens1)
```
This loop for i in tl.range(0, BLOCK_SIZE) serializes the processing of elements within a thread block by using tl.get_element. This is a Triton anti-pattern that underutilizes the GPU's parallel processing capabilities. While vectorizing a scatter operation with variable-sized segments is non-trivial, the current implementation is highly inefficient. A more performant, vectorized approach should be used to avoid this serialization.
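As a rough illustration, one vectorized alternative is a 2-D index grid with a single masked store. This is a sketch under the assumption that `num_tokens`, `start_idx`, and `src_val` are the 1-D blocks of size `BLOCK_SIZE` already loaded in the kernel; it is not the PR's actual code:

```python
# Build a (BLOCK_SIZE, MAX_NUM_TOKENS) index grid so all segments are
# written in one masked store instead of a serial per-segment loop.
cols = tl.arange(0, MAX_NUM_TOKENS)[None, :]   # (1, MAX_NUM_TOKENS)
dst = start_idx[:, None] + cols                # destination offset per lane
seg_mask = cols < num_tokens[:, None]          # clip each row to its segment length
vals = tl.broadcast_to(src_val[:, None], (BLOCK_SIZE, MAX_NUM_TOKENS))
tl.store(output_ptr + dst, vals, mask=seg_mask)
```

The mask clips each row to its variable segment length, so no data-dependent loop bound is needed.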
```python
vec_len = batch_size
n = cu_num_draft_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= 40:
    grid = 40
    BLOCK_SIZE = triton.next_power_of_2(n // grid)
```
This logic for calculating grid and BLOCK_SIZE contains a hardcoded 'magic number' 40 (line 158), which hurts readability and maintainability. It is also duplicated later in this file for expand_kernel (lines 294-300).
To improve the code, I suggest:
- Defining `40` as a named constant with a comment explaining its origin (e.g., `TRITON_GRID_SIZE = 40  # Empirically tuned value...`).
- Extracting this entire calculation into a helper function to avoid code duplication and ensure consistency.
This will make the code cleaner and easier to tune in the future.
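A minimal sketch of the suggested refactor, where `TRITON_GRID_SIZE` and `compute_launch_config` are hypothetical names rather than code from this PR:

```python
import triton

# Hypothetical: named constant replacing the magic number 40.
TRITON_GRID_SIZE = 40  # Empirically tuned cap on the launch grid.

def compute_launch_config(n: int) -> tuple[int, int]:
    """Return (grid, BLOCK_SIZE) for a 1-D launch over n elements."""
    block_size = 2
    grid = triton.cdiv(n, block_size)
    if n >= TRITON_GRID_SIZE:
        grid = TRITON_GRID_SIZE
        block_size = triton.next_power_of_2(n // grid)
    return grid, block_size
```

Both call sites (here and for `expand_kernel`) could then share this one helper, keeping the tuning in a single place.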
How was this patch tested?
Performance Benchmarking (vs. the original Triton implementations):
- rejection_greedy_sample_spec_len_1_triton
- rejection_greedy_sample_triton
- expand_kernel

Accuracy Testing: The new Triton implementations have passed comprehensive accuracy tests, ensuring full functional equivalence with the original Triton kernels while delivering superior performance.
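For readers unfamiliar with the kernel's semantics, here is a minimal pure-PyTorch reference of what greedy rejection sampling computes — a sketch useful for equivalence testing; the function name and padding convention are assumptions, not the PR's test code:

```python
import torch

def greedy_rejection_sample_ref(draft_ids, target_argmax, bonus_ids, pad=-1):
    # draft_ids, target_argmax: (batch, k); bonus_ids: (batch,).
    # Accept each draft token while it matches the target argmax; on the
    # first mismatch, emit the corrected target token and stop. If every
    # draft token is accepted, append the bonus token.
    batch, k = draft_ids.shape
    out = torch.full((batch, k + 1), pad, dtype=draft_ids.dtype)
    for b in range(batch):
        for i in range(k):
            out[b, i] = target_argmax[b, i]
            if draft_ids[b, i] != target_argmax[b, i]:
                break  # first rejection: corrected token emitted, stop
        else:
            out[b, k] = bonus_ids[b]  # all accepted: append bonus token
    return out

draft = torch.tensor([[5, 7], [5, 9]])
target = torch.tensor([[5, 7], [5, 7]])
bonus = torch.tensor([3, 4])
print(greedy_rejection_sample_ref(draft, target, bonus))
# row 0: both accepted -> [5, 7, 3]; row 1: second rejected -> [5, 7, -1]
```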
whx-sjtu
left a comment
LGTM
How do I use Triton for testing? I hit this error:

```
(Worker_TP6_EP6 pid=432498) ERROR 12-10 10:47:59 [multiproc_executor.py:822] for pos in tl.arange(0, BLOCK_SIZE):
```
This pull request has conflicts, please resolve those before we can evaluate the pull request.
```python
if not rejected:
    bonus_renew(
        bonus_token_ids_ptr,
        position,
        output_token_ids_ptr,
        max_spec_len,
        num_tokens1,
    )
```
Please add a comment to state why you encapsulate this method.
For certain inputs, no request goes through the 'if not rejected' branch. After optimizing and collecting profiling data, we found that this branch still accounted for a significant portion of MTE3 transfer time; removing it does not affect accuracy and improves performance by 40%. Since the branch may still be executed for other inputs, its body was factored out into a separate operator (bonus_renew), which is then called from inside rejection_greedy_sample_kernel. Experiments show performance improvements of varying degrees across different inputs. In the profiling view after this optimization, the part on the right shows rejection_greedy_sample_kernel's internal calls to the bonus_renew operator on the left.
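Schematically, the structure described above looks like the following — a hedged sketch in which the body of `bonus_renew` is an assumption (only its call signature appears in the diff above):

```python
import triton
import triton.language as tl

@triton.jit
def bonus_renew(bonus_token_ids_ptr, position, output_token_ids_ptr,
                max_spec_len, num_tokens):
    # Factored-out branch body: write the request's bonus token into the
    # output slot that follows its accepted draft tokens. Because this is
    # a separate @triton.jit function, its stores only execute when the
    # caller actually takes the 'not rejected' path.
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens,
             bonus_token_id)
```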
What this PR does / why we need it?
This PR introduces optimized Triton implementations of the rejection_greedy_sample_kernel and expand_kernel, delivering superior performance compared to the existing Triton implementations. The new kernels maintain full functional accuracy while providing significant performance improvements across various batch sizes and MTP configurations.
Does this PR introduce any user-facing change?
Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels:
- rejection_greedy_sample_kernel is enhanced with rejection_greedy_sample_spec_len_1_triton and rejection_greedy_sample_triton implementations
- expand_kernel receives a performance-optimized Triton version

These changes provide substantial performance improvements while maintaining backward compatibility.