[Ops][Triton] Add a triton kernel supporting partial rope. #4413
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a Triton kernel for Rotary Positional Embedding (RoPE) to support partial RoPE, where the RoPE dimension is not equal to the head dimension. This is a valuable performance optimization as it avoids explicit split and concat operations. The implementation includes a new Triton kernel _triton_rope and a wrapper function rope_forward_triton. The changes in sfa_v1.py correctly use this new kernel when Triton is available, with a fallback to the existing implementation. The Triton kernel itself appears to be well-written, handling both NEOX and non-NEOX styles, and correctly deals with padding and masking for variable dimensions. The upcasting to float32 for intermediate computations is a good practice for maintaining precision. I have one comment regarding a docstring that could be improved for clarity. Overall, the changes are logical and well-structured.

What this PR does / why we need it?
This PR adds a Triton RoPE kernel which supports scenarios where `rope_dim != head_dim`. This removes the split op before RoPE and the concat op after RoPE. Profiling shows an improvement.

Original implementation (2 split + 2 rope + 2 slice + 2 concat): because we currently only support piecewise aclgraph for DS 3.2, there are plenty of free bubbles, so you can see the computing time of all RoPE-related kernels: 35 us.
New Triton RoPE: 12 us.
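To make the saving concrete, here is a pure-PyTorch sketch of the baseline path the op count above refers to: the rope slice of q and k is split off, rotated, and concatenated back onto the pass-through slice. The helper name `neox_rope_ref` and the tensor layout are illustrative assumptions, not code from this repository; the fused Triton kernel replaces this whole sequence with a single in-place launch.

```python
import torch


def neox_rope_ref(x, cos, sin):
    """NEOX-style rotation on the last dim; cos/sin must broadcast against x,
    e.g. [num_tokens, 1, rope_dim // 2]. Upcast to float32 for precision."""
    x1, x2 = x.float().chunk(2, dim=-1)
    cos, sin = cos.float(), sin.float()
    out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return out.to(x.dtype)


def baseline_partial_rope(q, k, cos, sin, rope_dim):
    # 2 splits: separate the rope slice from the pass-through slice of q and k.
    q_rope, q_pass = q[..., :rope_dim], q[..., rope_dim:]
    k_rope, k_pass = k[..., :rope_dim], k[..., rope_dim:]
    # 2 rope calls on the rope slices only.
    q_rope = neox_rope_ref(q_rope, cos, sin)
    k_rope = neox_rope_ref(k_rope, cos, sin)
    # 2 concats to rebuild the full heads; a fused partial-RoPE kernel instead
    # rotates the first rope_dim elements in place and skips all of this.
    q = torch.cat([q_rope, q_pass], dim=-1)
    k = torch.cat([k_rope, k_pass], dim=-1)
    return q, k
```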

Does this PR introduce any user-facing change?
None
How was this patch tested?
I will add related unit tests after CI is integrated with Triton.