
Conversation

@yuxingcyx commented Dec 9, 2025

What this PR does / why we need it?

This PR introduces optimized Triton implementations of the rejection_greedy_sample_kernel and expand_kernel. The new kernels maintain full functional accuracy while delivering significant performance improvements over the existing Triton implementations across various batch sizes and MTP configurations.

Does this PR introduce any user-facing change?

Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels:

  • rejection_greedy_sample_kernel is enhanced with rejection_greedy_sample_spec_len_1_triton and rejection_greedy_sample_triton implementations

  • expand_kernel receives a performance-optimized Triton version

These changes provide substantial performance improvements while maintaining backward compatibility.

github-actions bot commented Dec 9, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces new Triton kernels for rejection sampling to improve performance, and the benchmarks provided show significant speedups. However, my review has identified several critical issues in the new Triton kernel implementations. The kernels currently use anti-patterns such as serializing work within a thread block by looping over elements and using data-dependent loop bounds. These patterns can lead to correctness issues and severely limit performance. Refactoring these kernels to use proper vectorization with masking is highly recommended to achieve optimal performance and ensure correctness. Additionally, there is duplicated code for calculating grid and block sizes and the use of magic numbers, which should be addressed to improve maintainability.

Comment on lines 619 to 628

```python
for pos in tl.arange(0, BLOCK_SIZE):
    draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
    target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
    position = block_idx * BLOCK_SIZE + pos
    if draft_token_id1 == target_argmax1:
        bonus_renew_1(
            bonus_token_ids_ptr,
            position,
            output_token_ids_ptr,
        )
```

critical

This loop iterates over BLOCK_SIZE and processes elements serially using tl.get_element. This is a Triton anti-pattern that prevents vectorization and causes thread divergence, which significantly degrades performance on a GPU. The logic can be fully vectorized using masks for much better efficiency. The bonus_renew_1 helper function would also become unnecessary.

```python
accepted_mask = (draft_token_id == target_argmax_id) & mask
# The bonus_renew_1 function can be inlined here for vectorization.
# Load bonus tokens only for accepted lanes.
bonus_token_id = tl.load(bonus_token_ids_ptr + offset, mask=accepted_mask)
# Store bonus tokens only for accepted lanes.
tl.store(output_token_ids_ptr + offset * 2 + 1, bonus_token_id, mask=accepted_mask)
```
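The masked pattern suggested above can be illustrated outside Triton with a small NumPy sketch (the data and block shape here are hypothetical; masked `tl.load`/`tl.store` behave like these masked gathers and scatters):

```python
import numpy as np

BLOCK_SIZE = 4
offset = np.arange(BLOCK_SIZE)
mask = offset < 3                       # only 3 valid lanes in this block

draft_token_id = np.array([7, 5, 9, 0])
target_argmax_id = np.array([7, 5, 2, 0])
bonus_token_ids = np.array([11, 12, 13, 14])
output_token_ids = np.full(BLOCK_SIZE * 2, -1)

# Accepted lanes: draft matches target argmax AND the lane is in bounds.
accepted_mask = (draft_token_id == target_argmax_id) & mask

# Masked "store": bonus tokens are written only for accepted lanes,
# at the odd output slots, with no per-element loop.
slots = offset * 2 + 1
output_token_ids[slots] = np.where(accepted_mask, bonus_token_ids,
                                   output_token_ids[slots])
```

All lanes are processed at once; the scalar `if` from the original loop becomes the `accepted_mask` predicate.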

Comment on lines +787 to +872

```python
for i in tl.range(0, BLOCK_SIZE):
    num_tokens1 = tl.get_element(num_tokens, (i, ))
    start_idx1 = tl.get_element(start_idx, (i, ))
    src_val1 = tl.get_element(src_val, (i, ))
    offset1 = tl.arange(0, MAX_NUM_TOKENS)
    tl.store(output_ptr + start_idx1 + offset1,
             src_val1,
             mask=offset1 < num_tokens1)
```

critical

This loop for i in tl.range(0, BLOCK_SIZE) serializes the processing of elements within a thread block by using tl.get_element. This is a Triton anti-pattern that underutilizes the GPU's parallel processing capabilities. While vectorizing a scatter operation with variable-sized segments is non-trivial, the current implementation is highly inefficient. A more performant, vectorized approach should be used to avoid this serialization.
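For intuition, when the start offsets are the exclusive prefix sum of the segment lengths (as with cumulative token counts), the variable-length fill collapses into a single vectorized repeat. A NumPy sketch with hypothetical per-request values:

```python
import numpy as np

num_tokens = np.array([2, 0, 3])                # tokens to expand per request
src_val = np.array([10, 20, 30])                # value replicated per request
start_idx = np.cumsum(num_tokens) - num_tokens  # exclusive prefix sum of lengths

# One vectorized operation writes every segment at once, instead of
# looping lane by lane: each src_val is repeated num_tokens times, and
# the segments land back to back exactly at start_idx.
output = np.repeat(src_val, num_tokens)
```

Requests with `num_tokens == 0` simply contribute nothing, so no masking or branching is needed.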

Comment on lines 154 to 238

```python
vec_len = batch_size
n = cu_num_draft_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= 40:
    grid = 40
    BLOCK_SIZE = triton.next_power_of_2(n // grid)
```

high

This logic for calculating grid and BLOCK_SIZE contains a hardcoded 'magic number' 40 (line 158), which hurts readability and maintainability. It is also duplicated later in this file for expand_kernel (lines 294-300).

To improve the code, I suggest:

  1. Defining 40 as a named constant with a comment explaining its origin (e.g., TRITON_GRID_SIZE = 40 # Empirically tuned value...).
  2. Extracting this entire calculation into a helper function to avoid code duplication and ensure consistency.

This will make the code cleaner and easier to tune in the future.
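A sketch of the suggested refactor (the constant and function names here are hypothetical, not from the PR; `next_power_of_2` mirrors `triton.next_power_of_2`):

```python
TRITON_GRID_SIZE = 40  # Empirically tuned upper bound on the launch grid.


def next_power_of_2(x: int) -> int:
    """Smallest power of two >= x (mirrors triton.next_power_of_2)."""
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def compute_grid_and_block(n: int) -> tuple[int, int]:
    """Shared grid/BLOCK_SIZE calculation for both kernels."""
    block_size = 2
    grid = -(-n // block_size)  # ceiling division, like triton.cdiv
    if n >= TRITON_GRID_SIZE:
        grid = TRITON_GRID_SIZE
        block_size = next_power_of_2(n // grid)
    return grid, block_size
```

Both launch sites would then call `compute_grid_and_block(n)` instead of duplicating the logic.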

@yuxingcyx (Author)
How was this patch tested?

Performance Benchmarking (vs original Triton implementations)

rejection_greedy_sample_spec_len_1_triton

| Batch Size | MTP | Triton old (μs) | Triton new (μs) |
|-----------:|----:|----------------:|----------------:|
| 2048 | 1 | 152.787 | 17.278 |
| 1024 | 1 | 77.089 | 10.786 |
| 512 | 1 | 39.209 | 8.302 |
| 256 | 1 | 20.07 | 6.847 |
| 128 | 1 | 10.918 | 5.974 |
| 64 | 1 | 6.343 | 4.373 |
| 32 | 1 | 3.614 | 2.991 |
| 8 | 1 | 1.913 | 1.845 |
| 1 | 1 | 2.776 | 2.101 |

rejection_greedy_sample_triton

| Batch Size | MTP | Triton old (μs) | Triton new (μs) |
|-----------:|----:|----------------:|----------------:|
| 2048 | 2 | 150.746 | 17.57 |
| 1024 | 2 | 76.854 | 11.102 |
| 512 | 2 | 38.973 | 7.66 |
| 256 | 2 | 20.092 | 6.06 |
| 128 | 2 | 10.614 | 5.16 |
| 64 | 2 | 6.294 | 5.418 |
| 32 | 2 | 3.659 | 3.007 |
| 8 | 2 | 1.933 | 2.03 |
| 1 | 2 | 2.774 | 2.203 |

expand_kernel

| Batch Size | MTP | Triton old (μs) | Triton new (μs) |
|-----------:|----:|----------------:|----------------:|
| 2048 | 1 | 153.107 | 18.099 |
| 1024 | 1 | 76.195 | 11.172 |
| 512 | 1 | 39.148 | 7.882 |
| 256 | 1 | 19.676 | 6.863 |
| 128 | 1 | 10.639 | 6.196 |
| 64 | 1 | 6.155 | 4.912 |
| 32 | 1 | 3.615 | 3.325 |
| 8 | 1 | 2.283 | 2.212 |
| 1 | 1 | 2.688 | 2.162 |
| 2048 | 2 | 152.795 | 17.016 |
| 1024 | 2 | 75.96 | 10.567 |
| 512 | 2 | 39.191 | 7.909 |
| 256 | 2 | 19.957 | 6.211 |
| 128 | 2 | 10.51 | 5.541 |
| 64 | 2 | 6.101 | 4.765 |
| 32 | 2 | 3.679 | 3.524 |
| 8 | 2 | 2.25 | 2.123 |
| 1 | 2 | 2.683 | 2.222 |

Accuracy Testing:

The new Triton implementations have passed comprehensive accuracy tests, ensuring full functional equivalence with the original Triton kernels while delivering superior performance.
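The semantics being verified can be sketched with a pure-Python reference model of greedy rejection sampling (a simplified sketch for accuracy checks, not the actual vLLM Ascend kernel code): the target's greedy token is emitted at every position, the first draft/target mismatch rejects the remaining tokens, and the bonus token survives only when every draft token is accepted.

```python
def rejection_greedy_sample_ref(draft_ids, target_argmax_ids, bonus_id):
    """Reference output for one request (hypothetical helper)."""
    out = []
    for draft, target in zip(draft_ids, target_argmax_ids):
        out.append(target)       # target's greedy token is always emitted
        if draft != target:      # first mismatch rejects the rest
            return out
    out.append(bonus_id)         # all drafts accepted: keep the bonus token
    return out
```

An accuracy test can then compare the old and new kernels' outputs against this reference for a sweep of batch sizes and MTP settings.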

@whx-sjtu added the ready (read for review) and ready-for-test (start test by label for PR) labels on Dec 9, 2025
@whx-sjtu (Collaborator) left a comment

LGTM

@gao12312

How to use Triton for the test?

```
(Worker_TP6_EP6 pid=432498) ERROR 12-10 10:47:59 [multiproc_executor.py:822] for pos in tl.arange(0, BLOCK_SIZE):
(Worker_TP7_EP7 pid=432957) ERROR 12-10 10:47:59 [multiproc_executor.py:822] RuntimeError('Only range and static_range iterators are currently supported')
```

@yuxingcyx force-pushed the triton-cyx-new branch 5 times, most recently from c64f377 to a60e4c8 on December 11, 2025 03:32
github-actions bot commented

This pull request has conflicts; please resolve them before we can evaluate the pull request.

Comment on lines +771 to +777

```python
if not rejected:
    bonus_renew(
        bonus_token_ids_ptr,
        position,
        output_token_ids_ptr,
        max_spec_len,
        num_tokens1,
```
Contributor:

Please add a comment to state why you encapsulate this method.

@yuxingcyx (Author)

For certain inputs, no request actually takes the 'if not rejected' branch. Profiling after the optimization showed that this branch still accounted for a significant portion of MTE3 transfer time, and removing it does not affect accuracy while improving performance by 40%. Since the branch may still be executed for other inputs, its steps were moved into a separate operator, bonus_renew, which rejection_greedy_sample_kernel calls internally. Experiments show performance improvements of varying degrees across different inputs. (In the profiling view after this optimization, the right-hand part shows the internal calls of the rejection_greedy_sample_kernel operator to the bonus_renew operator on the left.)
