[feat] Add causal_conv1d_update triton kernel #4307
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a new Triton kernel `_causal_conv1d_update_kernel_no_cache_len_no_mtp` for a specific case of causal 1D convolution and integrates it into the existing `causal_conv1d_update_npu` function. The addition is a good optimization, but I've found a critical bug in the new kernel: it fails to write outputs correctly for sequences longer than one token, causing data loss. My review includes a specific comment with a code suggestion to fix this issue.
```python
tl.store(
    out_ptr
    + pid * out_batch_stride
    + (doffs + tl.arange(0, DIM_BLOCK)) * out_len,
    result,
)
```
There is a critical bug in this loop. The `tl.store` operation writes to a memory location that does not depend on the loop variable `i`. As a result, for sequences with `seq_len > 1`, the output for each token is written to the same location, overwriting the previous one. Only the result for the last token (`i = seq_len - 1`) is preserved, leading to incorrect convolution output. To fix this, include the token index `i` in the output pointer calculation so that each token's result is stored in its correct position.
```diff
 tl.store(
     out_ptr
     + pid * out_batch_stride
-    + (doffs + tl.arange(0, DIM_BLOCK)) * out_len,
+    + (doffs + tl.arange(0, DIM_BLOCK)) * out_len + i,
     result,
 )
```
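To see the failure mode in isolation, below is a minimal, self-contained Triton sketch (illustrative only, not this PR's kernel; the kernel name and launch parameters are made up) showing why the `+ i` token offset matters: without it, every loop iteration stores to the same addresses and only one token's `result` survives.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _store_per_token_demo(
    out_ptr,           # float32 output of shape (batch, dim, seq_len), contiguous
    out_batch_stride,  # == dim * seq_len for a contiguous tensor
    out_len,           # == seq_len (stride between rows within one batch element)
    seq_len,
    DIM_BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    doffs = tl.arange(0, DIM_BLOCK)
    for i in range(seq_len):
        # Stand-in for the convolution result of token i.
        result = (doffs + i).to(tl.float32)
        tl.store(
            out_ptr
            + pid * out_batch_stride
            + doffs * out_len
            + i,  # drop this offset and every iteration overwrites the same slots
            result,
        )


batch, dim, seq_len = 2, 8, 4
out = torch.full((batch, dim, seq_len), -1.0, device="cuda")
_store_per_token_demo[(batch,)](out, out.stride(0), out.stride(1), seq_len,
                                DIM_BLOCK=dim)
# Every column is written: out[b, d, i] == d + i for all i, not just one position.
assert torch.equal(out[0, :, 0], torch.arange(dim, device="cuda", dtype=torch.float32))
```

Removing the `+ i` term reproduces the bug: all iterations write to column 0 (so the last token's `result` wins there) and the remaining columns keep their initial `-1.0`.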
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Signed-off-by: Ascendyh <[email protected]>
Force-pushed from e6c44b3 to 6d1ffa3
What this PR does / why we need it?
This PR introduces a new Triton kernel `_causal_conv1d_update_kernel_no_cache_len_no_mtp` to support efficient causal 1D convolution updates in Qwen3-next. The kernel is integrated into `causal_conv1d_update_npu`, enabling better performance on Ascend NPU hardware.
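For context on what such an update kernel computes, here is a hedged PyTorch reference sketch of a causal conv1d update step (a depthwise causal convolution over a rolling state, as used by Mamba-style layers; the function name, the `width - 1` state size, and the optional SiLU activation are illustrative assumptions, not taken from this PR):

```python
import torch
import torch.nn.functional as F


def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
    """x: (batch, dim, seq_len) new tokens; conv_state: (batch, dim, width - 1)
    rolling history, updated in place; weight: (dim, width) depthwise filter."""
    width = weight.shape[1]
    # Prepend the cached history, then run a depthwise (grouped) convolution;
    # with no padding the output is causal and has length seq_len.
    x_full = torch.cat([conv_state, x], dim=-1)
    out = F.conv1d(x_full, weight.unsqueeze(1), bias, groups=weight.shape[0])
    # Keep the last width - 1 inputs as the history for the next call.
    conv_state.copy_(x_full[..., -(width - 1):])
    return F.silu(out) if activation == "silu" else out
```

A fused Triton kernel can perform the concatenation, convolution, state update, and activation in a single pass over the state buffer, which is presumably where the performance win comes from.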
Does this PR introduce any user-facing change?
No. This is an internal implementation-level optimization: it improves model inference performance but does not change the user-facing API or how the feature is invoked.
How was this patch tested?